diff --git a/.bazelrc b/.bazelrc index 554440cfe3d..53485cb9743 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1 +1 @@ -build --cxxopt=-std=c++14 --host_cxxopt=-std=c++14 +build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 diff --git a/.gemini/config.yaml b/.gemini/config.yaml new file mode 100644 index 00000000000..71adf793964 --- /dev/null +++ b/.gemini/config.yaml @@ -0,0 +1,13 @@ +have_fun: false +memory_config: + disabled: false +code_review: + disable: false + comment_severity_threshold: MEDIUM + max_review_comments: -1 + pull_request_opened: + help: false + summary: false + code_review: false + include_drafts: false +ignore_patterns: [] diff --git a/.github/workflows/branch-testing.yml b/.github/workflows/branch-testing.yml new file mode 100644 index 00000000000..ece8ec4cd58 --- /dev/null +++ b/.github/workflows/branch-testing.yml @@ -0,0 +1,41 @@ +name: GitHub Actions Branch Testing + +on: + push: + branches: + - master + - 'v1.*' + schedule: + - cron: '54 19 * * SUN' # weekly at a "random" time + +permissions: + contents: read + +jobs: + arm64: + runs-on: ubuntu-24.04-arm + strategy: + matrix: + jre: [17] + fail-fast: false # Should swap to true if we grow a large matrix + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-java@v4 + with: + java-version: ${{ matrix.jre }} + distribution: 'temurin' + + - name: Gradle cache + uses: actions/cache@v4 + with: + path: | + ~/.gradle/caches + ~/.gradle/wrapper + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*', '**/gradle-wrapper.properties') }} + restore-keys: | + ${{ runner.os }}-gradle- + + - name: Build + run: ./gradlew -Dorg.gradle.parallel=true -Dorg.gradle.jvmargs='-Xmx1g' -PskipAndroid=true -PskipCodegen=true -PerrorProne=false test + diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 8c639cf14ed..ccabd9be79f 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - jre: [8, 
11, 17] + jre: [8, 11, 17, 21] fail-fast: false # Should swap to true if we grow a large matrix steps: @@ -77,8 +77,11 @@ jobs: bazel: runs-on: ubuntu-latest + strategy: + matrix: + bzlmod: [true, false] env: - USE_BAZEL_VERSION: 6.0.0 + USE_BAZEL_VERSION: 7.7.1 steps: - uses: actions/checkout@v4 @@ -97,19 +100,11 @@ jobs: key: ${{ runner.os }}-bazel-${{ env.USE_BAZEL_VERSION }}-${{ hashFiles('WORKSPACE', 'repositories.bzl') }} - name: Run bazel build - run: bazelisk build //... --enable_bzlmod=false - - - name: Run example bazel build - run: bazelisk build //... --enable_bzlmod=false - working-directory: ./examples + run: bazelisk build //... --enable_bzlmod=${{ matrix.bzlmod }} - - name: Run bazel build (bzlmod) - env: - USE_BAZEL_VERSION: 7.0.0 - run: bazelisk build //... --enable_bzlmod=true + - name: Run bazel test + run: bazelisk test //... --enable_bzlmod=${{ matrix.bzlmod }} - - name: Run example bazel build (bzlmod) - env: - USE_BAZEL_VERSION: 7.0.0 - run: bazelisk build //... --enable_bzlmod=true + - name: Run example bazel build + run: bazelisk build //... --enable_bzlmod=${{ matrix.bzlmod }} working-directory: ./examples diff --git a/.gitignore b/.gitignore index 92a0e3d6d3a..b078d891adf 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,9 @@ MODULE.bazel.lock .gitignore bin +# VsCode +.vscode + # OS X .DS_Store diff --git a/BUILD.bazel b/BUILD.bazel index b6d0838bf87..27a99fb62eb 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@rules_java//java:defs.bzl", "java_library", "java_plugin") load("@rules_jvm_external//:defs.bzl", "artifact") load(":java_grpc_library.bzl", "java_grpc_library") @@ -33,7 +35,6 @@ java_library( "//api", "//protobuf", "//stub", - "//stub:javax_annotation", "@com_google_protobuf//:protobuf_java", artifact("com.google.code.findbugs:jsr305"), artifact("com.google.guava:guava"), @@ -47,7 +48,6 @@ java_library( "//api", "//protobuf-lite", "//stub", - "//stub:javax_annotation", artifact("com.google.code.findbugs:jsr305"), artifact("com.google.guava:guava"), ], @@ -67,6 +67,5 @@ java_library( visibility = ["//:__subpackages__"], exports = [ artifact("com.google.auto.value:auto-value-annotations"), - artifact("org.apache.tomcat:annotations-api"), # @Generated for Java 9+ ], ) diff --git a/COMPILING.md b/COMPILING.md index de3cbb026c1..b7df1319beb 100644 --- a/COMPILING.md +++ b/COMPILING.md @@ -44,11 +44,11 @@ This section is only necessary if you are making changes to the code generation. Most users only need to use `skipCodegen=true` as discussed above. ### Build Protobuf -The codegen plugin is C++ code and requires protobuf 21.7 or later. +The codegen plugin is C++ code and requires protobuf 22.5 or later. For Linux, Mac and MinGW: ``` -$ PROTOBUF_VERSION=21.7 +$ PROTOBUF_VERSION=22.5 $ curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v$PROTOBUF_VERSION/protobuf-all-$PROTOBUF_VERSION.tar.gz $ tar xzf protobuf-all-$PROTOBUF_VERSION.tar.gz $ cd protobuf-$PROTOBUF_VERSION diff --git a/MAINTAINERS.md b/MAINTAINERS.md index f1c07ccd6f2..5048c7c5aca 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -11,8 +11,6 @@ for general contribution guidelines. 
- [ejona86](https://github.com/ejona86), Google LLC - [jdcormie](https://github.com/jdcormie), Google LLC - [kannanjgithub](https://github.com/kannanjgithub), Google LLC -- [larry-safran](https://github.com/larry-safran), Google LLC -- [markb74](https://github.com/markb74), Google LLC - [ran-su](https://github.com/ran-su), Google LLC - [sergiitk](https://github.com/sergiitk), Google LLC - [temawi](https://github.com/temawi), Google LLC @@ -26,7 +24,9 @@ for general contribution guidelines. - [ericgribkoff](https://github.com/ericgribkoff) - [jiangtaoli2016](https://github.com/jiangtaoli2016) - [jtattermusch](https://github.com/jtattermusch) +- [larry-safran](https://github.com/larry-safran) - [louiscryan](https://github.com/louiscryan) +- [markb74](https://github.com/markb74) - [nicolasnoble](https://github.com/nicolasnoble) - [nmittler](https://github.com/nmittler) - [sanjaypujare](https://github.com/sanjaypujare) diff --git a/MODULE.bazel b/MODULE.bazel index b60ea565073..803eacba297 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -1,81 +1,62 @@ module( name = "grpc-java", + version = "1.81.0-SNAPSHOT", # CURRENT_GRPC_VERSION compatibility_level = 0, repo_name = "io_grpc_grpc_java", - version = "1.68.0-SNAPSHOT", # CURRENT_GRPC_VERSION ) # GRPC_DEPS_START IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "com.google.android:annotations:4.1.1.4", - "com.google.api.grpc:proto-google-common-protos:2.29.0", - "com.google.auth:google-auth-library-credentials:1.23.0", - "com.google.auth:google-auth-library-oauth2-http:1.23.0", + "com.google.api.grpc:proto-google-common-protos:2.64.1", + "com.google.auth:google-auth-library-credentials:1.42.1", + "com.google.auth:google-auth-library-oauth2-http:1.42.1", "com.google.auto.value:auto-value-annotations:1.11.0", "com.google.auto.value:auto-value:1.11.0", "com.google.code.findbugs:jsr305:3.0.2", - "com.google.code.gson:gson:2.11.0", - "com.google.errorprone:error_prone_annotations:2.28.0", + "com.google.code.gson:gson:2.13.2", + 
"com.google.errorprone:error_prone_annotations:2.48.0", "com.google.guava:failureaccess:1.0.1", - "com.google.guava:guava:33.2.1-android", - "com.google.re2j:re2j:1.7", - "com.google.truth:truth:1.4.2", + "com.google.guava:guava:33.5.0-android", + "com.google.re2j:re2j:1.8", + "com.google.s2a.proto.v2:s2a-proto:0.1.3", + "com.google.truth:truth:1.4.5", "com.squareup.okhttp:okhttp:2.7.5", "com.squareup.okio:okio:2.10.0", # 3.0+ needs swapping to -jvm; need work to avoid flag-day - "io.netty:netty-buffer:4.1.110.Final", - "io.netty:netty-codec-http2:4.1.110.Final", - "io.netty:netty-codec-http:4.1.110.Final", - "io.netty:netty-codec-socks:4.1.110.Final", - "io.netty:netty-codec:4.1.110.Final", - "io.netty:netty-common:4.1.110.Final", - "io.netty:netty-handler-proxy:4.1.110.Final", - "io.netty:netty-handler:4.1.110.Final", - "io.netty:netty-resolver:4.1.110.Final", - "io.netty:netty-tcnative-boringssl-static:2.0.65.Final", - "io.netty:netty-tcnative-classes:2.0.65.Final", - "io.netty:netty-transport-native-epoll:jar:linux-x86_64:4.1.110.Final", - "io.netty:netty-transport-native-unix-common:4.1.110.Final", - "io.netty:netty-transport:4.1.110.Final", + "io.netty:netty-buffer:4.1.132.Final", + "io.netty:netty-codec-http2:4.1.132.Final", + "io.netty:netty-codec-http:4.1.132.Final", + "io.netty:netty-codec-socks:4.1.132.Final", + "io.netty:netty-codec:4.1.132.Final", + "io.netty:netty-common:4.1.132.Final", + "io.netty:netty-handler-proxy:4.1.132.Final", + "io.netty:netty-handler:4.1.132.Final", + "io.netty:netty-resolver:4.1.132.Final", + "io.netty:netty-tcnative-boringssl-static:2.0.75.Final", + "io.netty:netty-tcnative-classes:2.0.75.Final", + "io.netty:netty-transport-native-epoll:jar:linux-x86_64:4.1.132.Final", + "io.netty:netty-transport-native-unix-common:4.1.132.Final", + "io.netty:netty-transport:4.1.132.Final", "io.opencensus:opencensus-api:0.31.0", "io.opencensus:opencensus-contrib-grpc-metrics:0.31.0", "io.perfmark:perfmark-api:0.27.0", "junit:junit:4.13.2", 
- "org.apache.tomcat:annotations-api:6.0.53", - "org.codehaus.mojo:animal-sniffer-annotations:1.24", + "org.mockito:mockito-core:4.4.0", + "org.checkerframework:checker-qual:3.49.5", + "org.codehaus.mojo:animal-sniffer-annotations:1.27", ] # GRPC_DEPS_END +bazel_dep(name = "bazel_jar_jar", version = "0.1.11.bcr.1") bazel_dep(name = "bazel_skylib", version = "1.7.1") -bazel_dep(name = "googleapis", repo_name = "com_google_googleapis", version = "0.0.0-20240326-1c8d509c5") -# CEL Spec may be removed when cncf/xds MODULE is no longer using protobuf 27.x -bazel_dep(name = "cel-spec", repo_name = "dev_cel", version = "0.15.0") -bazel_dep(name = "grpc", repo_name = "com_github_grpc_grpc", version = "1.56.3.bcr.1") -bazel_dep(name = "grpc-proto", repo_name = "io_grpc_grpc_proto", version = "0.0.0-20240627-ec30f58") -bazel_dep(name = "protobuf", repo_name = "com_google_protobuf", version = "23.1") +bazel_dep(name = "googleapis", version = "0.0.0-20240326-1c8d509c5", repo_name = "com_google_googleapis") +bazel_dep(name = "grpc-proto", version = "0.0.0-20240627-ec30f58.bcr.1", repo_name = "io_grpc_grpc_proto") +bazel_dep(name = "protobuf", version = "33.4", repo_name = "com_google_protobuf") bazel_dep(name = "rules_cc", version = "0.0.9") -bazel_dep(name = "rules_java", version = "5.3.5") -bazel_dep(name = "rules_go", repo_name = "io_bazel_rules_go", version = "0.46.0") +bazel_dep(name = "rules_java", version = "9.1.0") bazel_dep(name = "rules_jvm_external", version = "6.0") -bazel_dep(name = "rules_proto", version = "5.3.0-21.7") - -non_module_deps = use_extension("//:repositories.bzl", "grpc_java_repositories_extension") - -use_repo( - non_module_deps, - "com_github_cncf_xds", - "envoy_api", -) - -grpc_repo_deps_ext = use_extension("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_repo_deps_ext") - -use_repo( - grpc_repo_deps_ext, - "com_envoyproxy_protoc_gen_validate", - "opencensus_proto", -) maven = use_extension("@rules_jvm_external//:extensions.bzl", "maven") - 
maven.install( artifacts = IO_GRPC_GRPC_JAVA_ARTIFACTS, repositories = [ @@ -83,124 +64,97 @@ maven.install( ], strict_visibility = True, ) - use_repo(maven, "maven") maven.override( coordinates = "com.google.protobuf:protobuf-java", target = "@com_google_protobuf//:protobuf_java", ) - maven.override( coordinates = "com.google.protobuf:protobuf-java-util", target = "@com_google_protobuf//:protobuf_java_util", ) - maven.override( coordinates = "com.google.protobuf:protobuf-javalite", target = "@com_google_protobuf//:protobuf_javalite", ) - maven.override( coordinates = "io.grpc:grpc-alts", target = "@io_grpc_grpc_java//alts", ) - maven.override( coordinates = "io.grpc:grpc-api", target = "@io_grpc_grpc_java//api", ) - maven.override( coordinates = "io.grpc:grpc-auth", target = "@io_grpc_grpc_java//auth", ) - maven.override( coordinates = "io.grpc:grpc-census", target = "@io_grpc_grpc_java//census", ) - maven.override( coordinates = "io.grpc:grpc-context", target = "@io_grpc_grpc_java//context", ) - maven.override( coordinates = "io.grpc:grpc-core", target = "@io_grpc_grpc_java//core:core_maven", ) - maven.override( coordinates = "io.grpc:grpc-googleapis", target = "@io_grpc_grpc_java//googleapis", ) - maven.override( coordinates = "io.grpc:grpc-grpclb", target = "@io_grpc_grpc_java//grpclb", ) - maven.override( coordinates = "io.grpc:grpc-inprocess", target = "@io_grpc_grpc_java//inprocess", ) - maven.override( coordinates = "io.grpc:grpc-netty", target = "@io_grpc_grpc_java//netty", ) - maven.override( coordinates = "io.grpc:grpc-netty-shaded", target = "@io_grpc_grpc_java//netty:shaded_maven", ) - maven.override( coordinates = "io.grpc:grpc-okhttp", target = "@io_grpc_grpc_java//okhttp", ) - maven.override( coordinates = "io.grpc:grpc-protobuf", target = "@io_grpc_grpc_java//protobuf", ) - maven.override( coordinates = "io.grpc:grpc-protobuf-lite", target = "@io_grpc_grpc_java//protobuf-lite", ) - maven.override( coordinates = "io.grpc:grpc-rls", target = 
"@io_grpc_grpc_java//rls", ) - maven.override( coordinates = "io.grpc:grpc-services", target = "@io_grpc_grpc_java//services:services_maven", ) - maven.override( coordinates = "io.grpc:grpc-stub", target = "@io_grpc_grpc_java//stub", ) - maven.override( coordinates = "io.grpc:grpc-testing", target = "@io_grpc_grpc_java//testing", ) - maven.override( coordinates = "io.grpc:grpc-xds", target = "@io_grpc_grpc_java//xds:xds_maven", ) - maven.override( coordinates = "io.grpc:grpc-util", target = "@io_grpc_grpc_java//util", ) - -switched_rules = use_extension("@com_google_googleapis//:extensions.bzl", "switched_rules") - -switched_rules.use_languages(java = True) diff --git a/README.md b/README.md index cb38ad66394..b0f7a6a14af 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ gRPC-Java - An RPC library and framework Supported Platforms ------------------- -gRPC-Java supports Java 8 and later. Android minSdkVersion 21 (Lollipop) and +gRPC-Java supports Java 8 and later. Android minSdkVersion 23 (Marshmallow) and later are supported with [Java 8 language desugaring][android-java-8]. TLS usage on Android typically requires Play Services Dynamic Security Provider. @@ -44,8 +44,8 @@ For a guided tour, take a look at the [quick start guide](https://grpc.io/docs/languages/java/quickstart) or the more explanatory [gRPC basics](https://grpc.io/docs/languages/java/basics). -The [examples](https://github.com/grpc/grpc-java/tree/v1.66.0/examples) and the -[Android example](https://github.com/grpc/grpc-java/tree/v1.66.0/examples/android) +The [examples](https://github.com/grpc/grpc-java/tree/v1.80.0/examples) and the +[Android example](https://github.com/grpc/grpc-java/tree/v1.80.0/examples/android) are standalone projects that showcase the usage of gRPC. Download @@ -56,42 +56,34 @@ Download [the JARs][]. 
Or for Maven with non-Android, add to your `pom.xml`: io.grpc grpc-netty-shaded - 1.66.0 + 1.80.0 runtime io.grpc grpc-protobuf - 1.66.0 + 1.80.0 io.grpc grpc-stub - 1.66.0 - - - org.apache.tomcat - annotations-api - 6.0.53 - provided + 1.80.0 ``` Or for Gradle with non-Android, add to your dependencies: ```gradle -runtimeOnly 'io.grpc:grpc-netty-shaded:1.66.0' -implementation 'io.grpc:grpc-protobuf:1.66.0' -implementation 'io.grpc:grpc-stub:1.66.0' -compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ +runtimeOnly 'io.grpc:grpc-netty-shaded:1.80.0' +implementation 'io.grpc:grpc-protobuf:1.80.0' +implementation 'io.grpc:grpc-stub:1.80.0' ``` For Android client, use `grpc-okhttp` instead of `grpc-netty-shaded` and `grpc-protobuf-lite` instead of `grpc-protobuf`: ```gradle -implementation 'io.grpc:grpc-okhttp:1.66.0' -implementation 'io.grpc:grpc-protobuf-lite:1.66.0' -implementation 'io.grpc:grpc-stub:1.66.0' -compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ +implementation 'io.grpc:grpc-okhttp:1.80.0' +implementation 'io.grpc:grpc-protobuf-lite:1.80.0' +implementation 'io.grpc:grpc-stub:1.80.0' ``` For [Bazel](https://bazel.build), you can either @@ -99,10 +91,10 @@ For [Bazel](https://bazel.build), you can either (with the GAVs from above), or use `@io_grpc_grpc_java//api` et al (see below). [the JARs]: -https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.66.0 +https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.80.0 Development snapshots are available in [Sonatypes's snapshot -repository](https://oss.sonatype.org/content/repositories/snapshots/). +repository](https://central.sonatype.com/repository/maven-snapshots/). 
Generated Code -------------- @@ -129,9 +121,9 @@ For protobuf-based codegen integrated with the Maven build system, you can use protobuf-maven-plugin 0.6.1 - com.google.protobuf:protoc:3.25.3:exe:${os.detected.classifier} + com.google.protobuf:protoc:3.25.8:exe:${os.detected.classifier} grpc-java - io.grpc:protoc-gen-grpc-java:1.66.0:exe:${os.detected.classifier} + io.grpc:protoc-gen-grpc-java:1.80.0:exe:${os.detected.classifier} @@ -152,16 +144,16 @@ For non-Android protobuf-based codegen integrated with the Gradle build system, you can use [protobuf-gradle-plugin][]: ```gradle plugins { - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' } protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.25.3" + artifact = "com.google.protobuf:protoc:3.25.8" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.66.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.80.0' } } generateProtoTasks { @@ -185,16 +177,16 @@ use protobuf-gradle-plugin but specify the 'lite' options: ```gradle plugins { - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' } protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.25.3" + artifact = "com.google.protobuf:protoc:3.25.8" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.66.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.80.0' } } generateProtoTasks { diff --git a/RELEASING.md b/RELEASING.md index bb1b77d0557..c57829b8c25 100644 --- a/RELEASING.md +++ b/RELEASING.md @@ -65,7 +65,7 @@ would be used to create all `v1.7` tags (e.g. `v1.7.0`, `v1.7.1`). ```bash git fetch upstream git checkout -b v$MAJOR.$MINOR.x \ - $(git log --pretty=format:%H --grep "^Start $MAJOR.$((MINOR+1)).0 development cycle$" upstream/master)^ + $(git log --pretty=format:%H --grep "^Start $MAJOR.$((MINOR+1)).0 development cycle" upstream/master)^ git push upstream v$MAJOR.$MINOR.x ``` 5. 
Continue with Google-internal steps at go/grpc-java/releasing, but stop @@ -132,7 +132,9 @@ Tagging the Release compiler/src/test{,Lite}/golden/Test{,Deprecated}Service.java.txt ./gradlew build git commit -a -m "Bump version to $MAJOR.$MINOR.$((PATCH+1))-SNAPSHOT" + git push -u origin release-v$MAJOR.$MINOR.$PATCH ``` + Raise a PR and set the base branch of the PR to v$MAJOR.$MINOR.x of the upstream grpc-java repo. 6. Go through PR review and push the release tag and updated release branch to GitHub (DO NOT click the merge button on the GitHub page): @@ -158,21 +160,21 @@ Tagging the Release repository can then be `released`, which will begin the process of pushing the new artifacts to Maven Central (the staging repository will be destroyed in the process). You can see the complete process for releasing to Maven - Central on the [OSSRH site](https://central.sonatype.org/pages/releasing-the-deployment.html). + Central on the [OSSRH site](https://central.sonatype.org/publish/publish-portal-ossrh-staging-api/#deploying). 10. We have containers for each release to detect compatibility regressions with old releases. Generate one for the new release by following the [GCR image generation instructions][gcr-image]. 
Summary: ```bash # If you haven't previously configured docker: - gcloud auth configure-docker + gcloud auth configure-docker us-docker.pkg.dev # In main grpc repo, add the new version to matrix ${EDITOR:-nano -w} tools/interop_matrix/client_matrix.py tools/interop_matrix/create_matrix_images.py --git_checkout --release=v$MAJOR.$MINOR.$PATCH \ --upload_images --language java - docker pull gcr.io/grpc-testing/grpc_interop_java:v$MAJOR.$MINOR.$PATCH - docker_image=gcr.io/grpc-testing/grpc_interop_java:v$MAJOR.$MINOR.$PATCH \ + docker pull us-docker.pkg.dev/grpc-testing/testing-images-public/grpc_interop_java:v$MAJOR.$MINOR.$PATCH + docker_image=us-docker.pkg.dev/grpc-testing/testing-images-public/grpc_interop_java:v$MAJOR.$MINOR.$PATCH \ tools/interop_matrix/testcases/java__master # Commit the changes diff --git a/SECURITY.md b/SECURITY.md index 5c5e3598b29..e710ceaabe1 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -330,14 +330,10 @@ is an option](#tls-with-conscrypt). Otherwise you need to [build your own 32-bit version of `netty-tcnative`](https://netty.io/wiki/forked-tomcat-native.html#wiki-h2-6). -If on Alpine Linux and you see "Error loading shared library libcrypt.so.1: No -such file or directory". Run `apk update && apk add libc6-compat` to install the -necessary dependency. - -If on Alpine Linux, try to use `grpc-netty-shaded` instead of `grpc-netty` or -(if you need `grpc-netty`) `netty-tcnative-boringssl-static` instead of -`netty-tcnative`. If those are not an option, you may consider using -[netty-tcnative-alpine](https://github.com/pires/netty-tcnative-alpine). +If on Alpine Linux, depending on your specific JDK you may see a crash in +netty_tcnative. This is generally caused by a missing symbol. Run `apk add +gcompat` and use the environment variable `LD_PRELOAD=/lib/libgcompat.so.0` when +executing Java. If on Fedora 30 or later and you see "libcrypt.so.1: cannot open shared object file: No such file or directory". 
Run `dnf -y install libxcrypt-compat` to @@ -399,7 +395,12 @@ grpc-netty version | netty-handler version | netty-tcnative-boringssl-static ver 1.57.x-1.58.x | 4.1.93.Final | 2.0.61.Final 1.59.x | 4.1.97.Final | 2.0.61.Final 1.60.x-1.66.x | 4.1.100.Final | 2.0.61.Final -1.67.x | 4.1.110.Final | 2.0.65.Final +1.67.x-1.70.x | 4.1.110.Final | 2.0.65.Final +1.71.x-1.74.x | 4.1.110.Final | 2.0.70.Final +1.75.x-1.76.x | 4.1.124.Final | 2.0.72.Final +1.77.x-1.78.x | 4.1.127.Final | 2.0.74.Final +1.79.x-1.80.x | 4.1.130.Final | 2.0.74.Final +1.81.x- | 4.1.132.Final | 2.0.75.Final _(grpc-netty-shaded avoids issues with keeping these versions in sync.)_ diff --git a/WORKSPACE b/WORKSPACE index 7bbfbcc5fa3..1efdf2793a8 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,43 +1,53 @@ workspace(name = "io_grpc_grpc_java") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("//:repositories.bzl", "IO_GRPC_GRPC_JAVA_ARTIFACTS", "IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS", "grpc_java_repositories") + +grpc_java_repositories() http_archive( name = "rules_java", - url = "https://github.com/bazelbuild/rules_java/releases/download/5.3.5/rules_java-5.3.5.tar.gz", - sha256 = "c73336802d0b4882e40770666ad055212df4ea62cfa6edf9cb0f9d29828a0934", + sha256 = "47632cc506c858011853073449801d648e10483d4b50e080ec2549a4b2398960", + urls = [ + "https://github.com/bazelbuild/rules_java/releases/download/8.15.2/rules_java-8.15.2.tar.gz", + ], ) -http_archive( - name = "rules_jvm_external", - sha256 = "d31e369b854322ca5098ea12c69d7175ded971435e55c18dd9dd5f29cc5249ac", - strip_prefix = "rules_jvm_external-5.3", - url = "https://github.com/bazelbuild/rules_jvm_external/releases/download/5.3/rules_jvm_external-5.3.tar.gz", -) +load("@com_google_protobuf//:protobuf_deps.bzl", "PROTOBUF_MAVEN_ARTIFACTS", "protobuf_deps") -load("@rules_jvm_external//:defs.bzl", "maven_install") -load("//:repositories.bzl", "IO_GRPC_GRPC_JAVA_ARTIFACTS") -load("//:repositories.bzl", 
"IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS") -load("//:repositories.bzl", "grpc_java_repositories") +protobuf_deps() -grpc_java_repositories() +load("@rules_java//java:rules_java_deps.bzl", "rules_java_dependencies") -load("@com_google_protobuf//:protobuf_deps.bzl", "PROTOBUF_MAVEN_ARTIFACTS") -load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") +rules_java_dependencies() -protobuf_deps() +load("@bazel_features//:deps.bzl", "bazel_features_deps") -load("@envoy_api//bazel:repositories.bzl", "api_dependencies") +bazel_features_deps() -api_dependencies() +load("@bazel_jar_jar//:jar_jar.bzl", "jar_jar_repositories") + +jar_jar_repositories() + +load("@rules_python//python:repositories.bzl", "py_repositories") + +py_repositories() load("@com_google_googleapis//:repository_rules.bzl", "switched_rules_by_language") switched_rules_by_language( name = "com_google_googleapis_imports", - java = True, ) +http_archive( + name = "rules_jvm_external", + sha256 = "d31e369b854322ca5098ea12c69d7175ded971435e55c18dd9dd5f29cc5249ac", + strip_prefix = "rules_jvm_external-5.3", + url = "https://github.com/bazelbuild/rules_jvm_external/releases/download/5.3/rules_jvm_external-5.3.tar.gz", +) + +load("@rules_jvm_external//:defs.bzl", "maven_install") + maven_install( artifacts = IO_GRPC_GRPC_JAVA_ARTIFACTS + PROTOBUF_MAVEN_ARTIFACTS, override_targets = IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS, diff --git a/alts/BUILD.bazel b/alts/BUILD.bazel index 73420e11053..f29df303fbe 100644 --- a/alts/BUILD.bazel +++ b/alts/BUILD.bazel @@ -1,5 +1,7 @@ +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") -load("@rules_proto//proto:defs.bzl", "proto_library") load("//:java_grpc_library.bzl", "java_grpc_library") java_library( @@ -12,12 +14,12 @@ java_library( ":handshaker_java_proto", "//api", 
"//core:internal", - "//grpclb", "//netty", "//stub", "@com_google_protobuf//:protobuf_java", "@com_google_protobuf//:protobuf_java_util", artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), artifact("io.netty:netty-buffer"), artifact("io.netty:netty-codec"), diff --git a/alts/build.gradle b/alts/build.gradle index 9477e2540af..c206a37bcef 100644 --- a/alts/build.gradle +++ b/alts/build.gradle @@ -2,8 +2,8 @@ plugins { id "java-library" id "maven-publish" - id "com.github.johnrengelman.shadow" id "com.google.protobuf" + id "com.gradleup.shadow" id "ru.vyarus.animalsniffer" } @@ -14,15 +14,12 @@ dependencies { implementation project(':grpc-auth'), project(':grpc-core'), project(":grpc-context"), // Override google-auth dependency with our newer version - project(':grpc-grpclb'), project(':grpc-protobuf'), project(':grpc-stub'), libraries.protobuf.java, libraries.conscrypt, - libraries.guava.jre, // JRE required by protobuf-java-util from grpclb libraries.google.auth.oauth2Http def nettyDependency = implementation project(':grpc-netty') - compileOnly libraries.javax.annotation shadow configurations.implementation.getDependencies().minus(nettyDependency) shadow project(path: ':grpc-netty-shaded', configuration: 'shadow') @@ -44,7 +41,11 @@ dependencies { classifier = "linux-x86_64" } } - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } configureProtoCompilation() diff --git a/alts/src/generated/main/grpc/io/grpc/alts/internal/HandshakerServiceGrpc.java b/alts/src/generated/main/grpc/io/grpc/alts/internal/HandshakerServiceGrpc.java index 2caba4a0544..07e4256eb75 100644 --- a/alts/src/generated/main/grpc/io/grpc/alts/internal/HandshakerServiceGrpc.java +++ b/alts/src/generated/main/grpc/io/grpc/alts/internal/HandshakerServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC 
proto compiler", - comments = "Source: grpc/gcp/handshaker.proto") @io.grpc.stub.annotations.GrpcGenerated public final class HandshakerServiceGrpc { @@ -60,6 +57,21 @@ public HandshakerServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOption return HandshakerServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static HandshakerServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public HandshakerServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new HandshakerServiceBlockingV2Stub(channel, callOptions); + } + }; + return HandshakerServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -157,6 +169,40 @@ public io.grpc.stub.StreamObserver doHandsh /** * A stub to allow clients to do synchronous rpc calls to service HandshakerService. */ + public static final class HandshakerServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private HandshakerServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected HandshakerServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new HandshakerServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Handshaker service accepts a stream of handshaker request, returning a
+     * stream of handshaker response. Client is expected to send exactly one
+     * message with either client_start or server_start followed by one or more
+     * messages with next. Each time client sends a request, the handshaker
+     * service expects to respond. Client does not have to wait for service's
+     * response before sending next request.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + doHandshake() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getDoHandshakeMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service HandshakerService. + */ public static final class HandshakerServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private HandshakerServiceBlockingStub( diff --git a/alts/src/main/java/io/grpc/alts/AltsContextUtil.java b/alts/src/main/java/io/grpc/alts/AltsContextUtil.java index 91b06756dc3..f45179bbd91 100644 --- a/alts/src/main/java/io/grpc/alts/AltsContextUtil.java +++ b/alts/src/main/java/io/grpc/alts/AltsContextUtil.java @@ -14,9 +14,10 @@ * limitations under the License. */ - package io.grpc.alts; +import io.grpc.Attributes; +import io.grpc.ClientCall; import io.grpc.ExperimentalApi; import io.grpc.ServerCall; import io.grpc.alts.internal.AltsInternalContext; @@ -29,14 +30,36 @@ public final class AltsContextUtil { private AltsContextUtil() {} /** - * Creates a {@link AltsContext} from ALTS context information in the {@link ServerCall}. + * Creates an {@link AltsContext} from ALTS context information in the {@link ServerCall}. * * @param call the {@link ServerCall} containing the ALTS information * @return the created {@link AltsContext} * @throws IllegalArgumentException if the {@link ServerCall} has no ALTS information. */ - public static AltsContext createFrom(ServerCall call) { - Object authContext = call.getAttributes().get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY); + public static AltsContext createFrom(ServerCall call) { + return createFrom(call.getAttributes()); + } + + /** + * Creates an {@link AltsContext} from ALTS context information in the {@link ClientCall}. 
+ * + * @param call the {@link ClientCall} containing the ALTS information + * @return the created {@link AltsContext} + * @throws IllegalArgumentException if the {@link ClientCall} has no ALTS information. + */ + public static AltsContext createFrom(ClientCall call) { + return createFrom(call.getAttributes()); + } + + /** + * Creates an {@link AltsContext} from ALTS context information in the {@link Attributes}. + * + * @param attributes the {@link Attributes} containing the ALTS information + * @return the created {@link AltsContext} + * @throws IllegalArgumentException if the {@link Attributes} has no ALTS information. + */ + public static AltsContext createFrom(Attributes attributes) { + Object authContext = attributes.get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY); if (!(authContext instanceof AltsInternalContext)) { throw new IllegalArgumentException("No ALTS context information found"); } @@ -49,8 +72,28 @@ public static AltsContext createFrom(ServerCall call) { * @param call the {@link ServerCall} to check * @return true, if the {@link ServerCall} contains ALTS information and false otherwise. */ - public static boolean check(ServerCall call) { - Object authContext = call.getAttributes().get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY); + public static boolean check(ServerCall call) { + return check(call.getAttributes()); + } + + /** + * Checks if the {@link ClientCall} contains ALTS information. + * + * @param call the {@link ClientCall} to check + * @return true, if the {@link ClientCall} contains ALTS information and false otherwise. + */ + public static boolean check(ClientCall call) { + return check(call.getAttributes()); + } + + /** + * Checks if the {@link Attributes} contains ALTS information. + * + * @param attributes the {@link Attributes} to check + * @return true, if the {@link Attributes} contains ALTS information and false otherwise. 
+ */ + public static boolean check(Attributes attributes) { + Object authContext = attributes.get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY); return authContext instanceof AltsInternalContext; } } diff --git a/alts/src/main/java/io/grpc/alts/DualCallCredentials.java b/alts/src/main/java/io/grpc/alts/DualCallCredentials.java new file mode 100644 index 00000000000..08104712e65 --- /dev/null +++ b/alts/src/main/java/io/grpc/alts/DualCallCredentials.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.alts; + +import io.grpc.CallCredentials; +import java.util.concurrent.Executor; + +/** + * {@code CallCredentials} that will pick the right credentials based on whether the established + * connection is ALTS or TLS. 
+ */ +final class DualCallCredentials extends CallCredentials { + private final CallCredentials tlsCallCredentials; + private final CallCredentials altsCallCredentials; + + public DualCallCredentials(CallCredentials tlsCallCreds, CallCredentials altsCallCreds) { + tlsCallCredentials = tlsCallCreds; + altsCallCredentials = altsCallCreds; + } + + @Override + public void applyRequestMetadata( + CallCredentials.RequestInfo requestInfo, + Executor appExecutor, + CallCredentials.MetadataApplier applier) { + if (AltsContextUtil.check(requestInfo.getTransportAttrs())) { + altsCallCredentials.applyRequestMetadata(requestInfo, appExecutor, applier); + } else { + tlsCallCredentials.applyRequestMetadata(requestInfo, appExecutor, applier); + } + } +} diff --git a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java index d9c2ddaaed7..1b5880120a4 100644 --- a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java +++ b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java @@ -63,6 +63,7 @@ public static Builder newBuilder() { */ public static final class Builder { private CallCredentials callCredentials; + private CallCredentials altsCallCredentials; private Builder() {} @@ -72,23 +73,32 @@ public Builder callCredentials(CallCredentials callCreds) { return this; } + /** Constructs GoogleDefaultChannelCredentials with an ALTS-specific call credential. */ + public Builder altsCallCredentials(CallCredentials callCreds) { + altsCallCredentials = callCreds; + return this; + } + /** Builds a GoogleDefaultChannelCredentials instance. 
*/ public ChannelCredentials build() { ChannelCredentials nettyCredentials = InternalNettyChannelCredentials.create(createClientFactory()); - if (callCredentials != null) { - return CompositeChannelCredentials.create(nettyCredentials, callCredentials); - } - CallCredentials callCreds; - try { - callCreds = MoreCallCredentials.from(GoogleCredentials.getApplicationDefault()); - } catch (IOException e) { - callCreds = - new FailingCallCredentials( - Status.UNAUTHENTICATED - .withDescription("Failed to get Google default credentials") - .withCause(e)); + CallCredentials tlsCallCreds = callCredentials; + if (tlsCallCreds == null) { + try { + tlsCallCreds = MoreCallCredentials.from(GoogleCredentials.getApplicationDefault()); + } catch (IOException e) { + tlsCallCreds = + new FailingCallCredentials( + Status.UNAUTHENTICATED + .withDescription("Failed to get Google default credentials") + .withCause(e)); + } } + CallCredentials callCreds = + altsCallCredentials == null + ? tlsCallCreds + : new DualCallCredentials(tlsCallCreds, altsCallCredentials); return CompositeChannelCredentials.create(nettyCredentials, callCreds); } diff --git a/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java b/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java index 8e8d175b7af..5e32d22d901 100644 --- a/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java +++ b/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java @@ -21,6 +21,7 @@ import io.grpc.ClientCall; import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourceHolder.Resource; import io.grpc.netty.NettyChannelBuilder; import io.netty.channel.EventLoopGroup; @@ -36,15 +37,37 @@ * application will have at most one connection to the handshaker service. */ final class HandshakerServiceChannel { + // Port 8080 is necessary for ALTS handshake. 
+ private static final int ALTS_PORT = 8080; + private static final String DEFAULT_TARGET = "metadata.google.internal.:8080"; static final Resource SHARED_HANDSHAKER_CHANNEL = - new ChannelResource("metadata.google.internal.:8080"); - + new ChannelResource(getHandshakerTarget(System.getenv("GCE_METADATA_HOST"))); + + /** + * Returns handshaker target. When GCE_METADATA_HOST is provided, it might contain port which we + * will discard and use ALTS_PORT instead. + */ + static String getHandshakerTarget(String envValue) { + if (envValue == null || envValue.isEmpty()) { + return DEFAULT_TARGET; + } + String host = envValue; + int portIndex = host.lastIndexOf(':'); + if (portIndex != -1) { + host = host.substring(0, portIndex); // Discard port if specified + } + return host + ":" + ALTS_PORT; // Utilize ALTS port in all cases + } + /** Returns a resource of handshaker service channel for testing only. */ static Resource getHandshakerChannelForTesting(String handshakerAddress) { return new ChannelResource(handshakerAddress); } + private static final boolean EXPERIMENTAL_ALTS_HANDSHAKER_KEEPALIVE_PARAMS = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_ALTS_HANDSHAKER_KEEPALIVE_PARAMS", false); + private static class ChannelResource implements Resource { private final String target; @@ -57,12 +80,16 @@ public Channel create() { /* Use its own event loop thread pool to avoid blocking. 
*/ EventLoopGroup eventGroup = new NioEventLoopGroup(1, new DefaultThreadFactory("handshaker pool", true)); - ManagedChannel channel = NettyChannelBuilder.forTarget(target) + NettyChannelBuilder channelBuilder = + NettyChannelBuilder.forTarget(target) .channelType(NioSocketChannel.class, InetSocketAddress.class) .directExecutor() .eventLoopGroup(eventGroup) - .usePlaintext() - .build(); + .usePlaintext(); + if (EXPERIMENTAL_ALTS_HANDSHAKER_KEEPALIVE_PARAMS) { + channelBuilder.keepAliveTime(10, TimeUnit.MINUTES).keepAliveTimeout(10, TimeUnit.SECONDS); + } + ManagedChannel channel = channelBuilder.build(); return new EventLoopHoldingChannel(channel, eventGroup); } diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java index e0343f83c51..9c51cf6a053 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java @@ -30,7 +30,6 @@ import io.grpc.SecurityLevel; import io.grpc.Status; import io.grpc.alts.internal.RpcProtocolVersionsUtil.RpcVersionsCheckResult; -import io.grpc.grpclb.GrpclbConstants; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalProtocolNegotiator; @@ -299,9 +298,7 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { isXdsDirectPath = isDirectPathCluster( grpcHandler.getEagAttributes().get(clusterNameAttrKey)); } - if (grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_ADDR_AUTHORITY) != null - || grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND) != null - || isXdsDirectPath) { + if (isXdsDirectPath) { TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority(), negotiationLogger); NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker); diff --git 
a/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java b/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java index 007db9e1eed..2d6c322c1b1 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java @@ -80,7 +80,7 @@ public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityExce return true; } int remaining = bytes.remaining(); - // Call handshaker service to proceess the bytes. + // Call handshaker service to process the bytes. if (outputFrame == null) { checkState(!isClient, "Client handshaker should not process any frame at the beginning."); outputFrame = handshaker.startServerHandshake(bytes); diff --git a/alts/src/main/java/io/grpc/alts/internal/AsyncSemaphore.java b/alts/src/main/java/io/grpc/alts/internal/AsyncSemaphore.java index 3ccdcfc763a..a8251c7fbd3 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AsyncSemaphore.java +++ b/alts/src/main/java/io/grpc/alts/internal/AsyncSemaphore.java @@ -16,12 +16,12 @@ package io.grpc.alts.internal; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import java.util.LinkedList; import java.util.Queue; -import javax.annotation.concurrent.GuardedBy; /** Provides a semaphore primitive, without blocking waiting on permits. 
*/ final class AsyncSemaphore { diff --git a/alts/src/test/java/io/grpc/alts/AltsContextUtilTest.java b/alts/src/test/java/io/grpc/alts/AltsContextUtilTest.java index 6fd2d840d45..675fa29fc99 100644 --- a/alts/src/test/java/io/grpc/alts/AltsContextUtilTest.java +++ b/alts/src/test/java/io/grpc/alts/AltsContextUtilTest.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.when; import io.grpc.Attributes; +import io.grpc.ClientCall; import io.grpc.ServerCall; import io.grpc.alts.AltsContext.SecurityLevel; import io.grpc.alts.internal.AltsInternalContext; @@ -37,27 +38,38 @@ /** Unit tests for {@link AltsContextUtil}. */ @RunWith(JUnit4.class) public class AltsContextUtilTest { - - private final ServerCall call = mock(ServerCall.class); - @Test public void check_noAttributeValue() { - when(call.getAttributes()).thenReturn(Attributes.newBuilder().build()); + assertFalse(AltsContextUtil.check(Attributes.newBuilder().build())); + } - assertFalse(AltsContextUtil.check(call)); + @Test + public void check_unexpectedAttributeValueType() { + assertFalse(AltsContextUtil.check(Attributes.newBuilder() + .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, new Object()) + .build())); } @Test - public void contains_unexpectedAttributeValueType() { + public void check_altsInternalContext() { + assertTrue(AltsContextUtil.check(Attributes.newBuilder() + .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, AltsInternalContext.getDefaultInstance()) + .build())); + } + + @Test + public void checkServer_altsInternalContext() { + ServerCall call = mock(ServerCall.class); when(call.getAttributes()).thenReturn(Attributes.newBuilder() - .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, new Object()) + .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, AltsInternalContext.getDefaultInstance()) .build()); - assertFalse(AltsContextUtil.check(call)); + assertTrue(AltsContextUtil.check(call)); } @Test - public void contains_altsInternalContext() { + public void checkClient_altsInternalContext() { + ClientCall call = 
mock(ClientCall.class); when(call.getAttributes()).thenReturn(Attributes.newBuilder() .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, AltsInternalContext.getDefaultInstance()) .build()); @@ -66,26 +78,57 @@ public void contains_altsInternalContext() { } @Test - public void from_altsInternalContext() { + public void createFrom_altsInternalContext() { HandshakerResult handshakerResult = HandshakerResult.newBuilder() .setPeerIdentity(Identity.newBuilder().setServiceAccount("remote@peer")) .setLocalIdentity(Identity.newBuilder().setServiceAccount("local@peer")) .build(); - when(call.getAttributes()).thenReturn(Attributes.newBuilder() - .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, new AltsInternalContext(handshakerResult)) - .build()); - AltsContext context = AltsContextUtil.createFrom(call); + AltsContext context = AltsContextUtil.createFrom(Attributes.newBuilder() + .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, new AltsInternalContext(handshakerResult)) + .build()); assertEquals("remote@peer", context.getPeerServiceAccount()); assertEquals("local@peer", context.getLocalServiceAccount()); assertEquals(SecurityLevel.INTEGRITY_AND_PRIVACY, context.getSecurityLevel()); } @Test(expected = IllegalArgumentException.class) - public void from_noAttributeValue() { - when(call.getAttributes()).thenReturn(Attributes.newBuilder().build()); + public void createFrom_noAttributeValue() { + AltsContextUtil.createFrom(Attributes.newBuilder().build()); + } - AltsContextUtil.createFrom(call); + @Test + public void createFromServer_altsInternalContext() { + HandshakerResult handshakerResult = + HandshakerResult.newBuilder() + .setPeerIdentity(Identity.newBuilder().setServiceAccount("remote@peer")) + .setLocalIdentity(Identity.newBuilder().setServiceAccount("local@peer")) + .build(); + + ServerCall call = mock(ServerCall.class); + when(call.getAttributes()).thenReturn(Attributes.newBuilder() + .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, new AltsInternalContext(handshakerResult)) + 
.build()); + + AltsContext context = AltsContextUtil.createFrom(call); + assertEquals("remote@peer", context.getPeerServiceAccount()); + } + + @Test + public void createFromClient_altsInternalContext() { + HandshakerResult handshakerResult = + HandshakerResult.newBuilder() + .setPeerIdentity(Identity.newBuilder().setServiceAccount("remote@peer")) + .setLocalIdentity(Identity.newBuilder().setServiceAccount("local@peer")) + .build(); + + ClientCall call = mock(ClientCall.class); + when(call.getAttributes()).thenReturn(Attributes.newBuilder() + .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, new AltsInternalContext(handshakerResult)) + .build()); + + AltsContext context = AltsContextUtil.createFrom(call); + assertEquals("remote@peer", context.getPeerServiceAccount()); } } diff --git a/alts/src/test/java/io/grpc/alts/DualCallCredentialsTest.java b/alts/src/test/java/io/grpc/alts/DualCallCredentialsTest.java new file mode 100644 index 00000000000..29646191be1 --- /dev/null +++ b/alts/src/test/java/io/grpc/alts/DualCallCredentialsTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.alts; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import io.grpc.Attributes; +import io.grpc.CallCredentials; +import io.grpc.CallCredentials.RequestInfo; +import io.grpc.MethodDescriptor; +import io.grpc.SecurityLevel; +import io.grpc.alts.internal.AltsInternalContext; +import io.grpc.alts.internal.AltsProtocolNegotiator; +import io.grpc.testing.TestMethodDescriptors; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Unit tests for {@link DualCallCredentials}. */ +@RunWith(JUnit4.class) +public class DualCallCredentialsTest { + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock CallCredentials tlsCallCredentials; + + @Mock CallCredentials altsCallCredentials; + + private static final String AUTHORITY = "testauthority"; + private static final SecurityLevel SECURITY_LEVEL = SecurityLevel.PRIVACY_AND_INTEGRITY; + + @Test + public void invokeTlsCallCredentials() { + DualCallCredentials callCredentials = + new DualCallCredentials(tlsCallCredentials, altsCallCredentials); + RequestInfo requestInfo = new RequestInfoImpl(false); + callCredentials.applyRequestMetadata(requestInfo, null, null); + + verify(altsCallCredentials, never()).applyRequestMetadata(any(), any(), any()); + verify(tlsCallCredentials, times(1)).applyRequestMetadata(requestInfo, null, null); + } + + @Test + public void invokeAltsCallCredentials() { + DualCallCredentials callCredentials = + new DualCallCredentials(tlsCallCredentials, altsCallCredentials); + RequestInfo requestInfo = new RequestInfoImpl(true); + callCredentials.applyRequestMetadata(requestInfo, null, null); + + verify(altsCallCredentials, times(1)).applyRequestMetadata(requestInfo, null, null); 
+ verify(tlsCallCredentials, never()).applyRequestMetadata(any(), any(), any()); + } + + private static final class RequestInfoImpl extends CallCredentials.RequestInfo { + private Attributes attrs; + + RequestInfoImpl(boolean hasAltsContext) { + attrs = + hasAltsContext + ? Attributes.newBuilder() + .set( + AltsProtocolNegotiator.AUTH_CONTEXT_KEY, + AltsInternalContext.getDefaultInstance()) + .build() + : Attributes.EMPTY; + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return TestMethodDescriptors.voidMethod(); + } + + @Override + public SecurityLevel getSecurityLevel() { + return SECURITY_LEVEL; + } + + @Override + public String getAuthority() { + return AUTHORITY; + } + + @Override + public Attributes getTransportAttrs() { + return attrs; + } + } +} diff --git a/alts/src/test/java/io/grpc/alts/HandshakerServiceChannelTest.java b/alts/src/test/java/io/grpc/alts/HandshakerServiceChannelTest.java index a3937904cd7..221001157f1 100644 --- a/alts/src/test/java/io/grpc/alts/HandshakerServiceChannelTest.java +++ b/alts/src/test/java/io/grpc/alts/HandshakerServiceChannelTest.java @@ -67,6 +67,24 @@ public void sharedChannel_authority() { } } + @Test + public void getHandshakerTarget_nullEnvVar() { + assertThat(HandshakerServiceChannel.getHandshakerTarget(null)) + .isEqualTo("metadata.google.internal.:8080"); + } + + @Test + public void getHandshakerTarget_envVarWithPort() { + assertThat(HandshakerServiceChannel.getHandshakerTarget("169.254.169.254:80")) + .isEqualTo("169.254.169.254:8080"); + } + + @Test + public void getHandshakerTarget_envVarWithHostOnly() { + assertThat(HandshakerServiceChannel.getHandshakerTarget("169.254.169.254")) + .isEqualTo("169.254.169.254:8080"); + } + @Test public void resource_works() { Channel channel = resource.create(); diff --git a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java index 24392af75fd..d47607ed90f 100644 
--- a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java @@ -202,8 +202,11 @@ public void operationComplete(ChannelFuture future) throws Exception { channel.flush(); // Capture the protected data written to the wire. - assertEquals(1, channel.outboundMessages().size()); - ByteBuf protectedData = channel.readOutbound(); + assertThat(channel.outboundMessages()).isNotEmpty(); + ByteBuf protectedData = channel.alloc().buffer(); + while (!channel.outboundMessages().isEmpty()) { + protectedData.writeBytes((ByteBuf) channel.readOutbound()); + } assertEquals(message.length(), writeCount.get()); // Read the protected message at the server and verify it matches the original message. @@ -327,16 +330,18 @@ public void doNotFlushEmptyBuffer() throws Exception { String message = "hello"; ByteBuf in = Unpooled.copiedBuffer(message, UTF_8); - assertEquals(0, protector.flushes.get()); + int flushes = protector.flushes.get(); Future done = channel.write(in); channel.flush(); + flushes++; done.get(5, TimeUnit.SECONDS); - assertEquals(1, protector.flushes.get()); + assertEquals(flushes, protector.flushes.get()); + // Flush does not propagate done = channel.write(Unpooled.EMPTY_BUFFER); channel.flush(); done.get(5, TimeUnit.SECONDS); - assertEquals(1, protector.flushes.get()); + assertEquals(flushes, protector.flushes.get()); } @Test diff --git a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java index 9a520720beb..14c19e554ae 100644 --- a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java @@ -29,7 +29,6 @@ import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ManagedChannel; -import io.grpc.grpclb.GrpclbConstants; 
import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; @@ -95,13 +94,6 @@ public void tearDown() { @Nullable abstract Attributes.Key getClusterNameAttrKey(); - @Test - public void altsHandler_lbProvidedBackend() { - Attributes attrs = - Attributes.newBuilder().set(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND, true).build(); - subtest_altsHandler(attrs); - } - @Test public void tlsHandler_emptyAttributes() { subtest_tlsHandler(Attributes.EMPTY); diff --git a/android-interop-testing/build.gradle b/android-interop-testing/build.gradle index 9b3b021afce..b61d50a6763 100644 --- a/android-interop-testing/build.gradle +++ b/android-interop-testing/build.gradle @@ -7,11 +7,10 @@ description = 'gRPC: Android Integration Testing' repositories { google() - mavenCentral() } android { - namespace 'io.grpc.android.integrationtest' + namespace = 'io.grpc.android.integrationtest' sourceSets { main { java { @@ -34,15 +33,11 @@ android { defaultConfig { applicationId "io.grpc.android.integrationtest" - // Held back to 20 as Gradle fails to build at the 21 level. This is - // presumably a Gradle bug that can be revisited later. 
- // Maybe this issue: https://github.com/gradle/gradle/issues/20778 - minSdkVersion 20 + minSdkVersion 23 targetSdkVersion 33 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" - multiDexEnabled true } buildTypes { debug { minifyEnabled false } @@ -63,21 +58,21 @@ android { dependencies { implementation 'androidx.appcompat:appcompat:1.3.0' - implementation 'androidx.multidex:multidex:2.0.0' implementation libraries.androidx.annotation implementation 'com.google.android.gms:play-services-base:18.0.1' implementation project(':grpc-android'), + project(':grpc-api'), project(':grpc-core'), project(':grpc-census'), project(':grpc-okhttp'), project(':grpc-protobuf-lite'), project(':grpc-stub'), project(':grpc-testing'), - libraries.hdrhistogram, libraries.junit, libraries.truth, libraries.androidx.test.rules, + libraries.androidx.test.core, libraries.opencensus.contrib.grpc.metrics implementation (project(':grpc-services')) { @@ -85,10 +80,8 @@ dependencies { exclude group: 'com.google.guava' } - compileOnly libraries.javax.annotation - - androidTestImplementation 'androidx.test.ext:junit:1.1.3', - 'androidx.test:runner:1.4.0' + androidTestImplementation libraries.androidx.test.ext.junit, + libraries.androidx.test.runner } // Checkstyle doesn't run automatically with android @@ -119,6 +112,25 @@ tasks.withType(JavaCompile).configureEach { "|") } +// Workaround error seen with Gradle 8.14.3 and AGP 7.4.1 when building: +// ./gradlew clean :grpc-android-interop-testing:build -PskipAndroid=false \ +// -Pandroid.useAndroidX=true --no-build-cache +// +// Error message: +// +// Execution failed for task ':grpc-android-interop-testing:mergeExtDexDebug'. +// > Could not resolve all files for configuration ':grpc-android-interop-testing:debugRuntimeClasspath'. 
+// > Failed to transform opencensus-contrib-grpc-metrics-0.31.1.jar (io.opencensus:opencensus-contrib-grpc-metrics:0.31.1) to match attributes {artifactType=android-dex, asm-transformed-variant=NONE, dexing-enable-desugaring=true, dexing-enable-jacoco-instrumentation=false, dexing-is-debuggable=true, dexing-min-sdk=23, org.gradle.category=library, org.gradle.libraryelements=jar, org.gradle.status=release, org.gradle.usage=java-runtime}. +// > Could not resolve all files for configuration ':grpc-android-interop-testing:debugRuntimeClasspath'. +// > Failed to transform grpc-api-1.81.0-SNAPSHOT.jar (project :grpc-api) to match attributes {artifactType=android-classes-jar, org.gradle.category=library, org.gradle.dependency.bundling=external, org.gradle.jvm.version=8, org.gradle.libraryelements=jar, org.gradle.usage=java-runtime}. +// > Execution failed for IdentityTransform: grpc-java/api/build/libs/grpc-api-1.81.0-SNAPSHOT.jar. +// > File/directory does not exist: grpc-java/api/build/libs/grpc-api-1.81.0-SNAPSHOT.jar +tasks.configureEach { task -> + if (task.name.equals("mergeExtDexDebug")) { + dependsOn(':grpc-api:jar') + } +} + afterEvaluate { // Hack to workaround "Task ':grpc-android-interop-testing:extractIncludeDebugProto' uses this // output of task ':grpc-context:jar' without declaring an explicit or implicit dependency." 
The diff --git a/android-interop-testing/src/androidTest/AndroidManifest.xml b/android-interop-testing/src/androidTest/AndroidManifest.xml index b0507f10ab9..3cc0a29a85f 100644 --- a/android-interop-testing/src/androidTest/AndroidManifest.xml +++ b/android-interop-testing/src/androidTest/AndroidManifest.xml @@ -5,8 +5,7 @@ android:name="androidx.test.runner.AndroidJUnitRunner" android:targetPackage="io.grpc.android.integrationtest" /> - + diff --git a/android-interop-testing/src/androidTest/java/io/grpc/android/integrationtest/UdsChannelInteropTest.java b/android-interop-testing/src/androidTest/java/io/grpc/android/integrationtest/UdsChannelInteropTest.java index f5e54da5d4e..5b98665ba29 100644 --- a/android-interop-testing/src/androidTest/java/io/grpc/android/integrationtest/UdsChannelInteropTest.java +++ b/android-interop-testing/src/androidTest/java/io/grpc/android/integrationtest/UdsChannelInteropTest.java @@ -19,9 +19,9 @@ import static org.junit.Assert.assertEquals; import android.net.LocalSocketAddress.Namespace; -import androidx.test.InstrumentationRegistry; +import androidx.test.ext.junit.rules.ActivityScenarioRule; import androidx.test.ext.junit.runners.AndroidJUnit4; -import androidx.test.rule.ActivityTestRule; +import androidx.test.platform.app.InstrumentationRegistry; import io.grpc.Grpc; import io.grpc.InsecureServerCredentials; import io.grpc.Server; @@ -60,8 +60,8 @@ public class UdsChannelInteropTest { // Ensures Looper is initialized for tests running on API level 15. Otherwise instantiating an // AsyncTask throws an exception. 
@Rule - public ActivityTestRule activityRule = - new ActivityTestRule(TesterActivity.class); + public ActivityScenarioRule activityRule = + new ActivityScenarioRule<>(TesterActivity.class); @Before public void setUp() throws IOException { diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java index e030fde13e3..33b914bb4b3 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java @@ -7,9 +7,6 @@ * A service used to obtain stats for verifying LB behavior. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class LoadBalancerStatsServiceGrpc { @@ -92,6 +89,21 @@ public LoadBalancerStatsServiceStub newStub(io.grpc.Channel channel, io.grpc.Cal return LoadBalancerStatsServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static LoadBalancerStatsServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public LoadBalancerStatsServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerStatsServiceBlockingV2Stub(channel, callOptions); + } + }; + return LoadBalancerStatsServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -212,6 +224,46 @@ public void 
getClientAccumulatedStats(io.grpc.testing.integration.Messages.LoadB * A service used to obtain stats for verifying LB behavior. * */ + public static final class LoadBalancerStatsServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private LoadBalancerStatsServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected LoadBalancerStatsServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerStatsServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Gets the backend distribution for RPCs sent by a test client.
+     * 
+ */ + public io.grpc.testing.integration.Messages.LoadBalancerStatsResponse getClientStats(io.grpc.testing.integration.Messages.LoadBalancerStatsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetClientStatsMethod(), getCallOptions(), request); + } + + /** + *
+     * Gets the accumulated stats for RPCs sent by a test client.
+     * 
+ */ + public io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsResponse getClientAccumulatedStats(io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetClientAccumulatedStatsMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service LoadBalancerStatsService. + *
+   * A service used to obtain stats for verifying LB behavior.
+   * 
+ */ public static final class LoadBalancerStatsServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private LoadBalancerStatsServiceBlockingStub( diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java index e8726d5adc4..c99abcff7cb 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/metrics.proto") @io.grpc.stub.annotations.GrpcGenerated public final class MetricsServiceGrpc { @@ -89,6 +86,21 @@ public MetricsServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions c return MetricsServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static MetricsServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public MetricsServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new MetricsServiceBlockingV2Stub(channel, callOptions); + } + }; + return MetricsServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -199,6 +211,46 @@ public void getGauge(io.grpc.testing.integration.Metrics.GaugeRequest request, /** * A stub to allow clients to do synchronous rpc calls to service MetricsService. 
*/ + public static final class MetricsServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private MetricsServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected MetricsServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new MetricsServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Returns the values of all the gauges that are currently being maintained by
+     * the service
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + getAllGauges(io.grpc.testing.integration.Metrics.EmptyMessage request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getGetAllGaugesMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns the value of one gauge
+     * 
+ */ + public io.grpc.testing.integration.Metrics.GaugeResponse getGauge(io.grpc.testing.integration.Metrics.GaugeRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetGaugeMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service MetricsService. + */ public static final class MetricsServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private MetricsServiceBlockingStub( diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java index 8ede6407cd0..fffcaad2df2 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java @@ -7,9 +7,6 @@ * A service used to control reconnect server. 
* */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ReconnectServiceGrpc { @@ -92,6 +89,21 @@ public ReconnectServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions return ReconnectServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ReconnectServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ReconnectServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReconnectServiceBlockingV2Stub(channel, callOptions); + } + }; + return ReconnectServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -200,6 +212,40 @@ public void stop(io.grpc.testing.integration.EmptyProtos.Empty request, * A service used to control reconnect server. 
* */ + public static final class ReconnectServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ReconnectServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ReconnectServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReconnectServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty start(io.grpc.testing.integration.Messages.ReconnectParams request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getStartMethod(), getCallOptions(), request); + } + + /** + */ + public io.grpc.testing.integration.Messages.ReconnectInfo stop(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getStopMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ReconnectService. + *
+   * A service used to control reconnect server.
+   * 
+ */ public static final class ReconnectServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ReconnectServiceBlockingStub( diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/TestServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/TestServiceGrpc.java index 01e2678a12f..1d7805e3a3f 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/TestServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/TestServiceGrpc.java @@ -8,9 +8,6 @@ * performance with various types of payload. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { @@ -273,6 +270,21 @@ public TestServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions call return TestServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -543,6 +555,125 @@ public void unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty requ * performance with various types of payload. 
* */ + public static final class TestServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * One empty request followed by one empty response.
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty emptyCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getEmptyCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by one response.
+     * 
+ */ + public io.grpc.testing.integration.Messages.SimpleResponse unaryCall(io.grpc.testing.integration.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by one response. Response has cache control
+     * headers set such that a caching HTTP proxy (such as GFE) can
+     * satisfy subsequent requests.
+     * 
+ */ + public io.grpc.testing.integration.Messages.SimpleResponse cacheableUnaryCall(io.grpc.testing.integration.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getCacheableUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by a sequence of responses (streamed download).
+     * The server returns the payload with client desired type and sizes.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingOutputCall(io.grpc.testing.integration.Messages.StreamingOutputCallRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamingOutputCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A sequence of requests followed by one response (streamed upload).
+     * The server returns the aggregated size of client payload as the result.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingInputCall() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getStreamingInputCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests with each request served by the server immediately.
+     * As one request could lead to multiple responses, this interface
+     * demonstrates the idea of full duplexing.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + fullDuplexCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getFullDuplexCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests followed by a sequence of responses.
+     * The server buffers all the client requests and then serves them in order. A
+     * stream of responses are returned to the client when the server starts with
+     * first request.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + halfDuplexCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getHalfDuplexCallMethod(), getCallOptions()); + } + + /** + *
+     * The test server will not implement this method. It will be used
+     * to test the behavior when clients call unimplemented methods.
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnimplementedCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestService. + *
+   * A simple service to test the various types of RPCs and experiment with
+   * performance with various types of payload.
+   * 
+ */ public static final class TestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestServiceBlockingStub( diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java index 743d68c3828..bec9b5a723a 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java @@ -8,9 +8,6 @@ * that case. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class UnimplementedServiceGrpc { @@ -63,6 +60,21 @@ public UnimplementedServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOpt return UnimplementedServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static UnimplementedServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public UnimplementedServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new UnimplementedServiceBlockingV2Stub(channel, callOptions); + } + }; + return UnimplementedServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -166,6 +178,37 @@ public void unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty requ * that case. 
* */ + public static final class UnimplementedServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private UnimplementedServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected UnimplementedServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new UnimplementedServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * A call that no server should implement
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnimplementedCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service UnimplementedService. + *
+   * A simple service NOT implemented at servers so clients can test for
+   * that case.
+   * 
+ */ public static final class UnimplementedServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private UnimplementedServiceBlockingStub( diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java index 61cfc19d29b..3453b6c01be 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java @@ -7,9 +7,6 @@ * A service to dynamically update the configuration of an xDS test client. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class XdsUpdateClientConfigureServiceGrpc { @@ -62,6 +59,21 @@ public XdsUpdateClientConfigureServiceStub newStub(io.grpc.Channel channel, io.g return XdsUpdateClientConfigureServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static XdsUpdateClientConfigureServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public XdsUpdateClientConfigureServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateClientConfigureServiceBlockingV2Stub(channel, callOptions); + } + }; + return XdsUpdateClientConfigureServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -161,6 +173,36 @@ public void 
configure(io.grpc.testing.integration.Messages.ClientConfigureReques * A service to dynamically update the configuration of an xDS test client. * */ + public static final class XdsUpdateClientConfigureServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private XdsUpdateClientConfigureServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected XdsUpdateClientConfigureServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateClientConfigureServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Update the test client's configuration.
+     * 
+ */ + public io.grpc.testing.integration.Messages.ClientConfigureResponse configure(io.grpc.testing.integration.Messages.ClientConfigureRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getConfigureMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service XdsUpdateClientConfigureService. + *
+   * A service to dynamically update the configuration of an xDS test client.
+   * 
+ */ public static final class XdsUpdateClientConfigureServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private XdsUpdateClientConfigureServiceBlockingStub( diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java index 6ba9419dedf..fb5f2cdebc7 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java @@ -7,9 +7,6 @@ * A service to remotely control health status of an xDS test server. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class XdsUpdateHealthServiceGrpc { @@ -92,6 +89,21 @@ public XdsUpdateHealthServiceStub newStub(io.grpc.Channel channel, io.grpc.CallO return XdsUpdateHealthServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static XdsUpdateHealthServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public XdsUpdateHealthServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateHealthServiceBlockingV2Stub(channel, callOptions); + } + }; + return XdsUpdateHealthServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -200,6 +212,40 @@ public void setNotServing(io.grpc.testing.integration.EmptyProtos.Empty request, * A service to remotely control health status of an xDS test server. 
* */ + public static final class XdsUpdateHealthServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private XdsUpdateHealthServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected XdsUpdateHealthServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateHealthServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty setServing(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSetServingMethod(), getCallOptions(), request); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty setNotServing(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSetNotServingMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service XdsUpdateHealthService. + *
+   * A service to remotely control health status of an xDS test server.
+   * 
+ */ public static final class XdsUpdateHealthServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private XdsUpdateHealthServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java index e030fde13e3..33b914bb4b3 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java @@ -7,9 +7,6 @@ * A service used to obtain stats for verifying LB behavior. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class LoadBalancerStatsServiceGrpc { @@ -92,6 +89,21 @@ public LoadBalancerStatsServiceStub newStub(io.grpc.Channel channel, io.grpc.Cal return LoadBalancerStatsServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static LoadBalancerStatsServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public LoadBalancerStatsServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerStatsServiceBlockingV2Stub(channel, callOptions); + } + }; + return LoadBalancerStatsServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -212,6 +224,46 @@ public void getClientAccumulatedStats(io.grpc.testing.integration.Messages.LoadB * A service used to obtain stats for verifying LB behavior. 
* */ + public static final class LoadBalancerStatsServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private LoadBalancerStatsServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected LoadBalancerStatsServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerStatsServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Gets the backend distribution for RPCs sent by a test client.
+     * 
+ */ + public io.grpc.testing.integration.Messages.LoadBalancerStatsResponse getClientStats(io.grpc.testing.integration.Messages.LoadBalancerStatsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetClientStatsMethod(), getCallOptions(), request); + } + + /** + *
+     * Gets the accumulated stats for RPCs sent by a test client.
+     * 
+ */ + public io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsResponse getClientAccumulatedStats(io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetClientAccumulatedStatsMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service LoadBalancerStatsService. + *
+   * A service used to obtain stats for verifying LB behavior.
+   * 
+ */ public static final class LoadBalancerStatsServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private LoadBalancerStatsServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java index e8726d5adc4..c99abcff7cb 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/metrics.proto") @io.grpc.stub.annotations.GrpcGenerated public final class MetricsServiceGrpc { @@ -89,6 +86,21 @@ public MetricsServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions c return MetricsServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static MetricsServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public MetricsServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new MetricsServiceBlockingV2Stub(channel, callOptions); + } + }; + return MetricsServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -199,6 +211,46 @@ public void getGauge(io.grpc.testing.integration.Metrics.GaugeRequest request, /** * A stub to allow clients to do synchronous rpc calls to service MetricsService. 
*/ + public static final class MetricsServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private MetricsServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected MetricsServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new MetricsServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Returns the values of all the gauges that are currently being maintained by
+     * the service
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + getAllGauges(io.grpc.testing.integration.Metrics.EmptyMessage request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getGetAllGaugesMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns the value of one gauge
+     * 
+ */ + public io.grpc.testing.integration.Metrics.GaugeResponse getGauge(io.grpc.testing.integration.Metrics.GaugeRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetGaugeMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service MetricsService. + */ public static final class MetricsServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private MetricsServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java index 8ede6407cd0..fffcaad2df2 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java @@ -7,9 +7,6 @@ * A service used to control reconnect server. 
* */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ReconnectServiceGrpc { @@ -92,6 +89,21 @@ public ReconnectServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions return ReconnectServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ReconnectServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ReconnectServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReconnectServiceBlockingV2Stub(channel, callOptions); + } + }; + return ReconnectServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -200,6 +212,40 @@ public void stop(io.grpc.testing.integration.EmptyProtos.Empty request, * A service used to control reconnect server. 
* */ + public static final class ReconnectServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ReconnectServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ReconnectServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReconnectServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty start(io.grpc.testing.integration.Messages.ReconnectParams request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getStartMethod(), getCallOptions(), request); + } + + /** + */ + public io.grpc.testing.integration.Messages.ReconnectInfo stop(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getStopMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ReconnectService. + *
+   * A service used to control reconnect server.
+   * 
+ */ public static final class ReconnectServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ReconnectServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/TestServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/TestServiceGrpc.java index 01e2678a12f..1d7805e3a3f 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/TestServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/TestServiceGrpc.java @@ -8,9 +8,6 @@ * performance with various types of payload. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { @@ -273,6 +270,21 @@ public TestServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions call return TestServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -543,6 +555,125 @@ public void unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty requ * performance with various types of payload. 
* */ + public static final class TestServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * One empty request followed by one empty response.
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty emptyCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getEmptyCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by one response.
+     * 
+ */ + public io.grpc.testing.integration.Messages.SimpleResponse unaryCall(io.grpc.testing.integration.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by one response. Response has cache control
+     * headers set such that a caching HTTP proxy (such as GFE) can
+     * satisfy subsequent requests.
+     * 
+ */ + public io.grpc.testing.integration.Messages.SimpleResponse cacheableUnaryCall(io.grpc.testing.integration.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getCacheableUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by a sequence of responses (streamed download).
+     * The server returns the payload with client desired type and sizes.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingOutputCall(io.grpc.testing.integration.Messages.StreamingOutputCallRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamingOutputCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A sequence of requests followed by one response (streamed upload).
+     * The server returns the aggregated size of client payload as the result.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingInputCall() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getStreamingInputCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests with each request served by the server immediately.
+     * As one request could lead to multiple responses, this interface
+     * demonstrates the idea of full duplexing.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + fullDuplexCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getFullDuplexCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests followed by a sequence of responses.
+     * The server buffers all the client requests and then serves them in order. A
+     * stream of responses are returned to the client when the server starts with
+     * first request.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + halfDuplexCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getHalfDuplexCallMethod(), getCallOptions()); + } + + /** + *
+     * The test server will not implement this method. It will be used
+     * to test the behavior when clients call unimplemented methods.
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnimplementedCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestService. + *
+   * A simple service to test the various types of RPCs and experiment with
+   * performance with various types of payload.
+   * 
+ */ public static final class TestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java index 743d68c3828..bec9b5a723a 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java @@ -8,9 +8,6 @@ * that case. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class UnimplementedServiceGrpc { @@ -63,6 +60,21 @@ public UnimplementedServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOpt return UnimplementedServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static UnimplementedServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public UnimplementedServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new UnimplementedServiceBlockingV2Stub(channel, callOptions); + } + }; + return UnimplementedServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -166,6 +178,37 @@ public void unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty requ * that case. 
* */ + public static final class UnimplementedServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private UnimplementedServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected UnimplementedServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new UnimplementedServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * A call that no server should implement
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnimplementedCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service UnimplementedService. + *
+   * A simple service NOT implemented at servers so clients can test for
+   * that case.
+   * 
+ */ public static final class UnimplementedServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private UnimplementedServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java index 61cfc19d29b..3453b6c01be 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java @@ -7,9 +7,6 @@ * A service to dynamically update the configuration of an xDS test client. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class XdsUpdateClientConfigureServiceGrpc { @@ -62,6 +59,21 @@ public XdsUpdateClientConfigureServiceStub newStub(io.grpc.Channel channel, io.g return XdsUpdateClientConfigureServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static XdsUpdateClientConfigureServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public XdsUpdateClientConfigureServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateClientConfigureServiceBlockingV2Stub(channel, callOptions); + } + }; + return XdsUpdateClientConfigureServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -161,6 +173,36 @@ public void 
configure(io.grpc.testing.integration.Messages.ClientConfigureReques * A service to dynamically update the configuration of an xDS test client. * */ + public static final class XdsUpdateClientConfigureServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private XdsUpdateClientConfigureServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected XdsUpdateClientConfigureServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateClientConfigureServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Update the test client's configuration.
+     * 
+ */ + public io.grpc.testing.integration.Messages.ClientConfigureResponse configure(io.grpc.testing.integration.Messages.ClientConfigureRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getConfigureMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service XdsUpdateClientConfigureService. + *
+   * A service to dynamically update the configuration of an xDS test client.
+   * 
+ */ public static final class XdsUpdateClientConfigureServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private XdsUpdateClientConfigureServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java index 6ba9419dedf..fb5f2cdebc7 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java @@ -7,9 +7,6 @@ * A service to remotely control health status of an xDS test server. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class XdsUpdateHealthServiceGrpc { @@ -92,6 +89,21 @@ public XdsUpdateHealthServiceStub newStub(io.grpc.Channel channel, io.grpc.CallO return XdsUpdateHealthServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static XdsUpdateHealthServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public XdsUpdateHealthServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateHealthServiceBlockingV2Stub(channel, callOptions); + } + }; + return XdsUpdateHealthServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -200,6 +212,40 @@ public void setNotServing(io.grpc.testing.integration.EmptyProtos.Empty request, * A service to remotely control health status of an xDS test 
server. * */ + public static final class XdsUpdateHealthServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private XdsUpdateHealthServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected XdsUpdateHealthServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateHealthServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty setServing(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSetServingMethod(), getCallOptions(), request); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty setNotServing(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSetNotServingMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service XdsUpdateHealthService. + *
+   * A service to remotely control health status of an xDS test server.
+   * 
+ */ public static final class XdsUpdateHealthServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private XdsUpdateHealthServiceBlockingStub( diff --git a/android-interop-testing/src/main/AndroidManifest.xml b/android-interop-testing/src/main/AndroidManifest.xml index 35f3ee33a2b..08c139e5880 100644 --- a/android-interop-testing/src/main/AndroidManifest.xml +++ b/android-interop-testing/src/main/AndroidManifest.xml @@ -5,19 +5,19 @@ - + + android:theme="@style/Base.V7.Theme.AppCompat.Light"> + android:exported="true"> diff --git a/android-interop-testing/src/main/java/io/grpc/android/integrationtest/TesterActivity.java b/android-interop-testing/src/main/java/io/grpc/android/integrationtest/TesterActivity.java index fb5b35c42d5..17c7e24cbfa 100644 --- a/android-interop-testing/src/main/java/io/grpc/android/integrationtest/TesterActivity.java +++ b/android-interop-testing/src/main/java/io/grpc/android/integrationtest/TesterActivity.java @@ -121,7 +121,7 @@ private void startTest(String testCase) { ((InputMethodManager) getSystemService(Context.INPUT_METHOD_SERVICE)).hideSoftInputFromWindow( hostEdit.getWindowToken(), 0); enableButtons(false); - resultText.setText("Testing..."); + resultText.setText(R.string.testing_message); String host = hostEdit.getText().toString(); String portStr = portEdit.getText().toString(); diff --git a/android-interop-testing/src/main/res/layout/activity_tester.xml b/android-interop-testing/src/main/res/layout/activity_tester.xml index e25bd1bb6f6..042da6437c0 100644 --- a/android-interop-testing/src/main/res/layout/activity_tester.xml +++ b/android-interop-testing/src/main/res/layout/activity_tester.xml @@ -16,6 +16,7 @@ android:layout_weight="2" android:layout_width="0dp" android:layout_height="wrap_content" + android:inputType="text" android:hint="Enter Host" /> gRPC Integration Test + Testing… diff --git a/android/build.gradle b/android/build.gradle index 3b3bfa59b96..e94bf03ff37 100644 --- a/android/build.gradle +++ 
b/android/build.gradle @@ -7,20 +7,20 @@ plugins { description = 'gRPC: Android' android { - namespace 'io.grpc.android' + namespace = 'io.grpc.android' compileOptions { sourceCompatibility JavaVersion.VERSION_1_8 targetCompatibility JavaVersion.VERSION_1_8 } compileSdkVersion 34 defaultConfig { - minSdkVersion 21 + minSdkVersion 23 targetSdkVersion 33 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" } - lintOptions { abortOnError true } + lintOptions { abortOnError = true } publishing { singleVariant('release') { withSourcesJar() @@ -31,7 +31,6 @@ android { repositories { google() - mavenCentral() } dependencies { diff --git a/android/src/main/java/io/grpc/android/AndroidChannelBuilder.java b/android/src/main/java/io/grpc/android/AndroidChannelBuilder.java index 317b7a50b74..3a750e02795 100644 --- a/android/src/main/java/io/grpc/android/AndroidChannelBuilder.java +++ b/android/src/main/java/io/grpc/android/AndroidChannelBuilder.java @@ -28,6 +28,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.errorprone.annotations.InlineMe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.CallOptions; import io.grpc.ClientCall; import io.grpc.ConnectivityState; @@ -41,7 +42,6 @@ import io.grpc.internal.GrpcUtil; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Builds a {@link ManagedChannel} that, when provided with a {@link Context}, will automatically @@ -217,7 +217,6 @@ private void configureNetworkMonitoring() { connectivityManager.registerDefaultNetworkCallback(defaultNetworkCallback); unregisterRunnable = new Runnable() { - @TargetApi(Build.VERSION_CODES.LOLLIPOP) @Override public void run() { connectivityManager.unregisterNetworkCallback(defaultNetworkCallback); @@ -231,7 +230,6 @@ public void run() { context.registerReceiver(networkReceiver, 
networkIntentFilter); unregisterRunnable = new Runnable() { - @TargetApi(Build.VERSION_CODES.LOLLIPOP) @Override public void run() { context.unregisterReceiver(networkReceiver); @@ -325,7 +323,6 @@ public void onBlockedStatusChanged(Network network, boolean blocked) { /** Respond to network changes. Only used on API levels < 24. */ private class NetworkReceiver extends BroadcastReceiver { - private boolean isConnected = false; @SuppressWarnings("deprecation") @Override @@ -333,9 +330,8 @@ public void onReceive(Context context, Intent intent) { ConnectivityManager conn = (ConnectivityManager) context.getSystemService(Context.CONNECTIVITY_SERVICE); android.net.NetworkInfo networkInfo = conn.getActiveNetworkInfo(); - boolean wasConnected = isConnected; - isConnected = networkInfo != null && networkInfo.isConnected(); - if (isConnected && !wasConnected) { + + if (networkInfo != null && networkInfo.isConnected()) { delegate.enterIdle(); } } diff --git a/android/src/main/java/io/grpc/android/UdsChannelBuilder.java b/android/src/main/java/io/grpc/android/UdsChannelBuilder.java index e2dc7232378..6f03aa0ee5e 100644 --- a/android/src/main/java/io/grpc/android/UdsChannelBuilder.java +++ b/android/src/main/java/io/grpc/android/UdsChannelBuilder.java @@ -21,6 +21,7 @@ import io.grpc.ExperimentalApi; import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannelBuilder; +import io.grpc.internal.GrpcUtil; import java.lang.reflect.InvocationTargetException; import javax.annotation.Nullable; import javax.net.SocketFactory; @@ -68,17 +69,20 @@ public static ManagedChannelBuilder forPath(String path, Namespace namespace) throw new UnsupportedOperationException("OkHttpChannelBuilder not found on the classpath"); } try { - // Target 'dns:///localhost' is unused, but necessary as an argument for OkHttpChannelBuilder. + // Target 'dns:///127.0.0.1' is unused, but necessary as an argument for OkHttpChannelBuilder. 
+ // An IP address is used instead of localhost to avoid a DNS lookup (see #11442). This should + // work even if IPv4 is unavailable, as the DNS resolver doesn't need working IPv4 to parse an + // IPv4 address. Unavailable IPv4 fails when we connect(), not at resolution time. // TLS is unsupported because Conscrypt assumes the platform Socket implementation to improve // performance by using the file descriptor directly. Object o = OKHTTP_CHANNEL_BUILDER_CLASS .getMethod("forTarget", String.class, ChannelCredentials.class) - .invoke(null, "dns:///localhost", InsecureChannelCredentials.create()); + .invoke(null, "dns:///127.0.0.1", InsecureChannelCredentials.create()); ManagedChannelBuilder builder = OKHTTP_CHANNEL_BUILDER_CLASS.cast(o); OKHTTP_CHANNEL_BUILDER_CLASS .getMethod("socketFactory", SocketFactory.class) .invoke(builder, new UdsSocketFactory(path, namespace)); - return builder; + return builder.proxyDetector(GrpcUtil.NOOP_PROXY_DETECTOR); } catch (IllegalAccessException e) { throw new RuntimeException("Failed to create OkHttpChannelBuilder", e); } catch (NoSuchMethodException e) { diff --git a/android/src/test/java/io/grpc/android/AndroidChannelBuilderTest.java b/android/src/test/java/io/grpc/android/AndroidChannelBuilderTest.java index 83367d93b32..c0884e4d7cf 100644 --- a/android/src/test/java/io/grpc/android/AndroidChannelBuilderTest.java +++ b/android/src/test/java/io/grpc/android/AndroidChannelBuilderTest.java @@ -152,12 +152,6 @@ public void networkChanges_api23() { .sendBroadcast(new Intent(ConnectivityManager.CONNECTIVITY_ACTION)); assertThat(delegateChannel.enterIdleCount).isEqualTo(1); - // The broadcast receiver may fire when the active network status has not actually changed - ApplicationProvider - .getApplicationContext() - .sendBroadcast(new Intent(ConnectivityManager.CONNECTIVITY_ACTION)); - assertThat(delegateChannel.enterIdleCount).isEqualTo(1); - // Drop the connection shadowOf(connectivityManager).setActiveNetworkInfo(null); 
ApplicationProvider diff --git a/api/BUILD.bazel b/api/BUILD.bazel index 6bf3375e9f0..6de00d6272d 100644 --- a/api/BUILD.bazel +++ b/api/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( @@ -6,7 +7,6 @@ java_library( "src/main/java/**/*.java", "src/context/java/**/*.java", ]), - javacopts = ["-Xep:DoNotCall:OFF"], # Remove once requiring Bazel 3.4.0+; allows non-final visibility = ["//visibility:public"], deps = [ artifact("com.google.code.findbugs:jsr305"), @@ -15,3 +15,21 @@ java_library( artifact("com.google.guava:guava"), ], ) + +java_library( + name = "test_fixtures", + testonly = 1, + srcs = glob([ + "src/testFixtures/java/io/grpc/**/*.java", + ]), + visibility = ["//xds:__pkg__"], + deps = [ + "//core", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("com.google.truth:truth"), + artifact("junit:junit"), + artifact("org.mockito:mockito-core"), + ], +) diff --git a/api/build.gradle b/api/build.gradle index 1d21c7bdcb6..745fa00b3f1 100644 --- a/api/build.gradle +++ b/api/build.gradle @@ -47,9 +47,22 @@ dependencies { testImplementation project(':grpc-core') testImplementation project(':grpc-testing') testImplementation libraries.guava.testlib + testImplementation libraries.truth - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } +} + +animalsniffer { + annotation = 'io.grpc.IgnoreJRERequirement' } tasks.named("javadoc").configure { @@ -60,6 +73,7 @@ tasks.named("javadoc").configure { exclude 'io/grpc/Internal?*.java' exclude 'io/grpc/MetricRecorder.java' exclude 'io/grpc/MetricSink.java' + exclude 'io/grpc/Uri.java' } tasks.named("sourcesJar").configure { diff 
--git a/api/src/context/java/io/grpc/Deadline.java b/api/src/context/java/io/grpc/Deadline.java index 62b803267a8..92eeba5ffce 100644 --- a/api/src/context/java/io/grpc/Deadline.java +++ b/api/src/context/java/io/grpc/Deadline.java @@ -16,8 +16,10 @@ package io.grpc; -import java.util.Arrays; +import static java.util.Objects.requireNonNull; + import java.util.Locale; +import java.util.Objects; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -33,7 +35,7 @@ * passed to the various components unambiguously. */ public final class Deadline implements Comparable { - private static final SystemTicker SYSTEM_TICKER = new SystemTicker(); + private static final Ticker SYSTEM_TICKER = new SystemTicker(); // nanoTime has a range of just under 300 years. Only allow up to 100 years in the past or future // to prevent wraparound as long as process runs for less than ~100 years. private static final long MAX_OFFSET = TimeUnit.DAYS.toNanos(100 * 365); @@ -91,7 +93,7 @@ public static Deadline after(long duration, TimeUnit units) { * @since 1.24.0 */ public static Deadline after(long duration, TimeUnit units, Ticker ticker) { - checkNotNull(units, "units"); + requireNonNull(units, "units"); return new Deadline(ticker, units.toNanos(duration), true); } @@ -191,8 +193,8 @@ public long timeRemaining(TimeUnit unit) { * @return {@link ScheduledFuture} which can be used to cancel execution of the task */ public ScheduledFuture runOnExpiration(Runnable task, ScheduledExecutorService scheduler) { - checkNotNull(task, "task"); - checkNotNull(scheduler, "scheduler"); + requireNonNull(task, "task"); + requireNonNull(scheduler, "scheduler"); return scheduler.schedule(task, deadlineNanos - ticker.nanoTime(), TimeUnit.NANOSECONDS); } @@ -225,37 +227,27 @@ public String toString() { @Override public int compareTo(Deadline that) { checkTicker(that); - long diff = this.deadlineNanos - that.deadlineNanos; - if 
(diff < 0) { - return -1; - } else if (diff > 0) { - return 1; - } - return 0; + return Long.compare(this.deadlineNanos, that.deadlineNanos); } @Override public int hashCode() { - return Arrays.asList(this.ticker, this.deadlineNanos).hashCode(); + return Objects.hash(this.ticker, this.deadlineNanos); } @Override - public boolean equals(final Object o) { - if (o == this) { + public boolean equals(final Object object) { + if (object == this) { return true; } - if (!(o instanceof Deadline)) { - return false; - } - - final Deadline other = (Deadline) o; - if (this.ticker == null ? other.ticker != null : this.ticker != other.ticker) { + if (!(object instanceof Deadline)) { return false; } - if (this.deadlineNanos != other.deadlineNanos) { + final Deadline that = (Deadline) object; + if (this.ticker == null ? that.ticker != null : this.ticker != that.ticker) { return false; } - return true; + return this.deadlineNanos == that.deadlineNanos; } /** @@ -275,24 +267,17 @@ public boolean equals(final Object o) { * @since 1.24.0 */ public abstract static class Ticker { - /** Returns the number of nanoseconds since this source's epoch. */ + /** Returns the number of nanoseconds elapsed since this ticker's reference point in time. 
*/ public abstract long nanoTime(); } - private static class SystemTicker extends Ticker { + private static final class SystemTicker extends Ticker { @Override public long nanoTime() { return System.nanoTime(); } } - private static T checkNotNull(T reference, Object errorMessage) { - if (reference == null) { - throw new NullPointerException(String.valueOf(errorMessage)); - } - return reference; - } - private void checkTicker(Deadline other) { if (ticker != other.ticker) { throw new AssertionError( diff --git a/api/src/main/java/io/grpc/Attributes.java b/api/src/main/java/io/grpc/Attributes.java index de00e63554c..c8550d176b4 100644 --- a/api/src/main/java/io/grpc/Attributes.java +++ b/api/src/main/java/io/grpc/Attributes.java @@ -215,6 +215,7 @@ public int hashCode() { * The helper class to build an Attributes instance. */ public static final class Builder { + // Exactly one of base and newdata will be set private Attributes base; private IdentityHashMap, Object> newdata; @@ -225,8 +226,11 @@ private Builder(Attributes base) { private IdentityHashMap, Object> data(int size) { if (newdata == null) { - newdata = new IdentityHashMap<>(size); + newdata = new IdentityHashMap<>(base.data.size() + size); + newdata.putAll(base.data); + base = null; } + assert base == null; return newdata; } @@ -243,12 +247,11 @@ public Builder set(Key key, T value) { * @return this */ public Builder discard(Key key) { - if (base.data.containsKey(key)) { - IdentityHashMap, Object> newBaseData = new IdentityHashMap<>(base.data); - newBaseData.remove(key); - base = new Attributes(newBaseData); - } - if (newdata != null) { + if (base != null) { + if (base.data.containsKey(key)) { + data(0).remove(key); + } + } else { newdata.remove(key); } return this; @@ -264,11 +267,6 @@ public Builder setAll(Attributes other) { */ public Attributes build() { if (newdata != null) { - for (Map.Entry, Object> entry : base.data.entrySet()) { - if (!newdata.containsKey(entry.getKey())) { - 
newdata.put(entry.getKey(), entry.getValue()); - } - } base = new Attributes(newdata); newdata = null; } diff --git a/api/src/main/java/io/grpc/CallCredentials.java b/api/src/main/java/io/grpc/CallCredentials.java index 31b68b22dae..eb92a6f15fa 100644 --- a/api/src/main/java/io/grpc/CallCredentials.java +++ b/api/src/main/java/io/grpc/CallCredentials.java @@ -43,7 +43,7 @@ public abstract class CallCredentials { *

It is called for each individual RPC, within the {@link Context} of the call, before the * stream is about to be created on a transport. Implementations should not block in this * method. If metadata is not immediately available, e.g., needs to be fetched from network, the - * implementation may give the {@code applier} to an asynchronous task which will eventually call + * implementation may give the {@code appExecutor} an asynchronous task which will eventually call * the {@code applier}. The RPC proceeds only after the {@code applier} is called. * * @param requestInfo request-related information diff --git a/api/src/main/java/io/grpc/CallOptions.java b/api/src/main/java/io/grpc/CallOptions.java index 25c4df386a1..800bdfb6c90 100644 --- a/api/src/main/java/io/grpc/CallOptions.java +++ b/api/src/main/java/io/grpc/CallOptions.java @@ -17,16 +17,18 @@ package io.grpc; import static com.google.common.base.Preconditions.checkArgument; +import static io.grpc.TimeUtils.convertToNanos; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.annotation.concurrent.Immutable; @@ -176,6 +178,11 @@ public CallOptions withDeadlineAfter(long duration, TimeUnit unit) { return withDeadline(Deadline.after(duration, unit)); } + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11657") + public CallOptions withDeadlineAfter(Duration duration) { + return withDeadlineAfter(convertToNanos(duration), TimeUnit.NANOSECONDS); + } + /** * Returns the deadline or {@code null} if the deadline is not set. 
*/ diff --git a/api/src/main/java/io/grpc/ClientCall.java b/api/src/main/java/io/grpc/ClientCall.java index df9e15001e1..c915c8beaac 100644 --- a/api/src/main/java/io/grpc/ClientCall.java +++ b/api/src/main/java/io/grpc/ClientCall.java @@ -67,7 +67,7 @@ * manner, and notifies gRPC library to receive additional response after one is consumed by * a fictional processResponse(). * - *

+ * 
  *   call = channel.newCall(bidiStreamingMethod, callOptions);
  *   listener = new ClientCall.Listener<FooResponse>() {
  *     @Override
diff --git a/api/src/main/java/io/grpc/ClientStreamTracer.java b/api/src/main/java/io/grpc/ClientStreamTracer.java
index 2f366b7404c..42e1fdfebea 100644
--- a/api/src/main/java/io/grpc/ClientStreamTracer.java
+++ b/api/src/main/java/io/grpc/ClientStreamTracer.java
@@ -132,12 +132,15 @@ public static final class StreamInfo {
     private final CallOptions callOptions;
     private final int previousAttempts;
     private final boolean isTransparentRetry;
+    private final boolean isHedging;
 
     StreamInfo(
-        CallOptions callOptions, int previousAttempts, boolean isTransparentRetry) {
+        CallOptions callOptions, int previousAttempts, boolean isTransparentRetry,
+        boolean isHedging) {
       this.callOptions = checkNotNull(callOptions, "callOptions");
       this.previousAttempts = previousAttempts;
       this.isTransparentRetry = isTransparentRetry;
+      this.isHedging = isHedging;
     }
 
     /**
@@ -165,6 +168,15 @@ public boolean isTransparentRetry() {
       return isTransparentRetry;
     }
 
+    /**
+     * Whether the stream is hedging.
+     *
+     * @since 1.74.0
+     */
+    public boolean isHedging() {
+      return isHedging;
+    }
+
     /**
      * Converts this StreamInfo into a new Builder.
      *
@@ -174,7 +186,9 @@ public Builder toBuilder() {
       return new Builder()
           .setCallOptions(callOptions)
           .setPreviousAttempts(previousAttempts)
-          .setIsTransparentRetry(isTransparentRetry);
+          .setIsTransparentRetry(isTransparentRetry)
+          .setIsHedging(isHedging);
+
     }
 
     /**
@@ -192,6 +206,7 @@ public String toString() {
           .add("callOptions", callOptions)
           .add("previousAttempts", previousAttempts)
           .add("isTransparentRetry", isTransparentRetry)
+          .add("isHedging", isHedging)
           .toString();
     }
 
@@ -204,6 +219,7 @@ public static final class Builder {
       private CallOptions callOptions = CallOptions.DEFAULT;
       private int previousAttempts;
       private boolean isTransparentRetry;
+      private boolean isHedging;
 
       Builder() {
       }
@@ -236,11 +252,21 @@ public Builder setIsTransparentRetry(boolean isTransparentRetry) {
         return this;
       }
 
+      /**
+       * Sets whether the stream is hedging.
+       *
+       * @since 1.74.0
+       */
+      public Builder setIsHedging(boolean isHedging) {
+        this.isHedging = isHedging;
+        return this;
+      }
+
       /**
        * Builds a new StreamInfo.
        */
       public StreamInfo build() {
-        return new StreamInfo(callOptions, previousAttempts, isTransparentRetry);
+        return new StreamInfo(callOptions, previousAttempts, isTransparentRetry, isHedging);
       }
     }
   }
diff --git a/api/src/main/java/io/grpc/ConfiguratorRegistry.java b/api/src/main/java/io/grpc/ConfiguratorRegistry.java
index b2efcc1cff4..19d6703d308 100644
--- a/api/src/main/java/io/grpc/ConfiguratorRegistry.java
+++ b/api/src/main/java/io/grpc/ConfiguratorRegistry.java
@@ -16,10 +16,10 @@
 
 package io.grpc;
 
+import com.google.errorprone.annotations.concurrent.GuardedBy;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import javax.annotation.concurrent.GuardedBy;
 
 /**
  * A registry for {@link Configurator} instances.
@@ -33,9 +33,9 @@ final class ConfiguratorRegistry {
   @GuardedBy("this")
   private boolean wasConfiguratorsSet;
   @GuardedBy("this")
-  private boolean configFrozen;
-  @GuardedBy("this")
   private List configurators = Collections.emptyList();
+  @GuardedBy("this")
+  private int configuratorsCallCountBeforeSet = 0;
 
   ConfiguratorRegistry() {}
 
@@ -56,11 +56,10 @@ public static synchronized ConfiguratorRegistry getDefaultRegistry() {
    * @throws IllegalStateException if this method is called more than once
    */
   public synchronized void setConfigurators(List configurators) {
-    if (configFrozen) {
+    if (wasConfiguratorsSet) {
       throw new IllegalStateException("Configurators are already set");
     }
     this.configurators = Collections.unmodifiableList(new ArrayList<>(configurators));
-    configFrozen = true;
     wasConfiguratorsSet = true;
   }
 
@@ -68,10 +67,20 @@ public synchronized void setConfigurators(List configura
    * Returns a list of the configurators in this registry.
    */
   public synchronized List getConfigurators() {
-    configFrozen = true;
+    if (!wasConfiguratorsSet) {
+      configuratorsCallCountBeforeSet++;
+    }
     return configurators;
   }
 
+  /**
+   * Returns the number of times getConfigurators() was called before
+   * setConfigurators() was successfully invoked.
+   */
+  public synchronized int getConfiguratorsCallCountBeforeSet() {
+    return configuratorsCallCountBeforeSet;
+  }
+
   public synchronized boolean wasSetConfiguratorsCalled() {
     return wasConfiguratorsSet;
   }
diff --git a/api/src/main/java/io/grpc/ConnectivityState.java b/api/src/main/java/io/grpc/ConnectivityState.java
index 677039b2517..a7407efb2e9 100644
--- a/api/src/main/java/io/grpc/ConnectivityState.java
+++ b/api/src/main/java/io/grpc/ConnectivityState.java
@@ -20,7 +20,7 @@
  * The connectivity states.
  *
  * @see 
- * more information
+ *     more information
  */
 @ExperimentalApi("https://github.com/grpc/grpc-java/issues/4359")
 public enum ConnectivityState {
diff --git a/api/src/main/java/io/grpc/EquivalentAddressGroup.java b/api/src/main/java/io/grpc/EquivalentAddressGroup.java
index 4b3db006684..18151e88aba 100644
--- a/api/src/main/java/io/grpc/EquivalentAddressGroup.java
+++ b/api/src/main/java/io/grpc/EquivalentAddressGroup.java
@@ -50,6 +50,20 @@ public final class EquivalentAddressGroup {
   @ExperimentalApi("https://github.com/grpc/grpc-java/issues/6138")
   public static final Attributes.Key ATTR_AUTHORITY_OVERRIDE =
       Attributes.Key.create("io.grpc.EquivalentAddressGroup.ATTR_AUTHORITY_OVERRIDE");
+  /**
+   * The name of the locality that this EquivalentAddressGroup is in.
+   */
+  public static final Attributes.Key ATTR_LOCALITY_NAME =
+      Attributes.Key.create("io.grpc.EquivalentAddressGroup.LOCALITY");
+  /**
+   * Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32.
+   * Must not be zero. The weight is proportional to the other endpoints; if an endpoint's weight is
+   * twice that of another endpoint, it is intended to receive twice the load.
+   */
+  @Attr
+  static final Attributes.Key ATTR_WEIGHT =
+      Attributes.Key.create("io.grpc.EquivalentAddressGroup.ATTR_WEIGHT");
+
   private final List addrs;
   private final Attributes attrs;
 
@@ -108,7 +122,9 @@ public Attributes getAttributes() {
 
   @Override
   public String toString() {
-    // TODO(zpencer): Summarize return value if addr is very large
+    // EquivalentAddressGroup is intended to contain a small number of addresses for the same
+    // endpoint (e.g., IPv4/IPv6). Aggregating many groups into a single EquivalentAddressGroup
+    // is no longer done, so this no longer needs summarization.
     return "[" + addrs + "/" + attrs + "]";
   }
 
diff --git a/api/src/main/java/io/grpc/FeatureFlags.java b/api/src/main/java/io/grpc/FeatureFlags.java
new file mode 100644
index 00000000000..0e414ed7b31
--- /dev/null
+++ b/api/src/main/java/io/grpc/FeatureFlags.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright 2026 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Strings;
+
+class FeatureFlags {
+  private static boolean enableRfc3986Uris = getFlag("GRPC_ENABLE_RFC3986_URIS", false);
+
+  /** Whether to parse targets as RFC 3986 URIs (true), or use {@link java.net.URI} (false). */
+  @VisibleForTesting
+  static boolean setRfc3986UrisEnabled(boolean value) {
+    boolean prevValue = enableRfc3986Uris;
+    enableRfc3986Uris = value;
+    return prevValue;
+  }
+
+  /** Whether to parse targets as RFC 3986 URIs (true), or use {@link java.net.URI} (false). */
+  static boolean getRfc3986UrisEnabled() {
+    return enableRfc3986Uris;
+  }
+
+  static boolean getFlag(String envVarName, boolean enableByDefault) {
+    String envVar = System.getenv(envVarName);
+    if (envVar == null) {
+      envVar = System.getProperty(envVarName);
+    }
+    if (envVar != null) {
+      envVar = envVar.trim();
+    }
+    if (enableByDefault) {
+      return Strings.isNullOrEmpty(envVar) || Boolean.parseBoolean(envVar);
+    } else {
+      return !Strings.isNullOrEmpty(envVar) && Boolean.parseBoolean(envVar);
+    }
+  }
+
+  private FeatureFlags() {}
+}
diff --git a/api/src/main/java/io/grpc/ForwardingChannelBuilder2.java b/api/src/main/java/io/grpc/ForwardingChannelBuilder2.java
index 7f21a57ec80..78fe730d91a 100644
--- a/api/src/main/java/io/grpc/ForwardingChannelBuilder2.java
+++ b/api/src/main/java/io/grpc/ForwardingChannelBuilder2.java
@@ -263,6 +263,12 @@ protected T addMetricSink(MetricSink metricSink) {
     return thisT();
   }
 
+  @Override
+  public  T setNameResolverArg(NameResolver.Args.Key key, X value) {
+    delegate().setNameResolverArg(key, value);
+    return thisT();
+  }
+
   /**
    * Returns the {@link ManagedChannel} built by the delegate by default. Overriding method can
    * return different value.
diff --git a/api/src/main/java/io/grpc/ForwardingServerBuilder.java b/api/src/main/java/io/grpc/ForwardingServerBuilder.java
index 9cef7cfa331..d1f183dd824 100644
--- a/api/src/main/java/io/grpc/ForwardingServerBuilder.java
+++ b/api/src/main/java/io/grpc/ForwardingServerBuilder.java
@@ -201,6 +201,12 @@ public Server build() {
     return delegate().build();
   }
 
+  @Override
+  public T addMetricSink(MetricSink metricSink) {
+    delegate().addMetricSink(metricSink);
+    return thisT();
+  }
+
   @Override
   public String toString() {
     return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString();
diff --git a/api/src/main/java/io/grpc/Grpc.java b/api/src/main/java/io/grpc/Grpc.java
index baa9f5f0ab6..a45c613fd18 100644
--- a/api/src/main/java/io/grpc/Grpc.java
+++ b/api/src/main/java/io/grpc/Grpc.java
@@ -56,6 +56,13 @@ private Grpc() {
   public static final Attributes.Key TRANSPORT_ATTR_SSL_SESSION =
       Attributes.Key.create("io.grpc.Grpc.TRANSPORT_ATTR_SSL_SESSION");
 
+  /**
+   * The value for the custom label of per-RPC metrics. Defaults to empty string when unset. Must
+   * not be set to {@code null}.
+   */
+  public static final CallOptions.Key CALL_OPTION_CUSTOM_LABEL =
+      CallOptions.Key.createWithDefault("io.grpc.Grpc.CALL_OPTION_CUSTOM_LABEL", "");
+
   /**
    * Annotation for transport attributes. It follows the annotation semantics defined
    * by {@link Attributes}.
diff --git a/api/src/main/java/io/grpc/HttpConnectProxiedSocketAddress.java b/api/src/main/java/io/grpc/HttpConnectProxiedSocketAddress.java
index d59c53db1d1..0df8dc452c1 100644
--- a/api/src/main/java/io/grpc/HttpConnectProxiedSocketAddress.java
+++ b/api/src/main/java/io/grpc/HttpConnectProxiedSocketAddress.java
@@ -23,6 +23,9 @@
 import com.google.common.base.Objects;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
 import javax.annotation.Nullable;
 
 /**
@@ -33,6 +36,8 @@ public final class HttpConnectProxiedSocketAddress extends ProxiedSocketAddress
 
   private final SocketAddress proxyAddress;
   private final InetSocketAddress targetAddress;
+  @SuppressWarnings("serial")
+  private final Map headers;
   @Nullable
   private final String username;
   @Nullable
@@ -41,6 +46,7 @@ public final class HttpConnectProxiedSocketAddress extends ProxiedSocketAddress
   private HttpConnectProxiedSocketAddress(
       SocketAddress proxyAddress,
       InetSocketAddress targetAddress,
+      Map headers,
       @Nullable String username,
       @Nullable String password) {
     checkNotNull(proxyAddress, "proxyAddress");
@@ -53,6 +59,7 @@ private HttpConnectProxiedSocketAddress(
     }
     this.proxyAddress = proxyAddress;
     this.targetAddress = targetAddress;
+    this.headers = headers;
     this.username = username;
     this.password = password;
   }
@@ -87,6 +94,14 @@ public InetSocketAddress getTargetAddress() {
     return targetAddress;
   }
 
+  /**
+   * Returns the custom HTTP headers to be sent during the HTTP CONNECT handshake.
+   */
+  @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12479")
+  public Map getHeaders() {
+    return headers;
+  }
+
   @Override
   public boolean equals(Object o) {
     if (!(o instanceof HttpConnectProxiedSocketAddress)) {
@@ -95,13 +110,14 @@ public boolean equals(Object o) {
     HttpConnectProxiedSocketAddress that = (HttpConnectProxiedSocketAddress) o;
     return Objects.equal(proxyAddress, that.proxyAddress)
         && Objects.equal(targetAddress, that.targetAddress)
+        && Objects.equal(headers, that.headers)
         && Objects.equal(username, that.username)
         && Objects.equal(password, that.password);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hashCode(proxyAddress, targetAddress, username, password);
+    return Objects.hashCode(proxyAddress, targetAddress, username, password, headers);
   }
 
   @Override
@@ -109,6 +125,7 @@ public String toString() {
     return MoreObjects.toStringHelper(this)
         .add("proxyAddr", proxyAddress)
         .add("targetAddr", targetAddress)
+        .add("headers", headers)
         .add("username", username)
         // Intentionally mask out password
         .add("hasPassword", password != null)
@@ -129,6 +146,7 @@ public static final class Builder {
 
     private SocketAddress proxyAddress;
     private InetSocketAddress targetAddress;
+    private Map headers = Collections.emptyMap();
     @Nullable
     private String username;
     @Nullable
@@ -153,6 +171,18 @@ public Builder setTargetAddress(InetSocketAddress targetAddress) {
       return this;
     }
 
+    /**
+     * Sets custom HTTP headers to be sent during the HTTP CONNECT handshake. This is an optional
+     * field. The headers will be sent in addition to any authentication headers (if username and
+     * password are set).
+     */
+    @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12479")
+    public Builder setHeaders(Map headers) {
+      this.headers = Collections.unmodifiableMap(
+          new HashMap<>(checkNotNull(headers, "headers")));
+      return this;
+    }
+
     /**
      * Sets the username used to connect to the proxy.  This is an optional field and can be {@code
      * null}.
@@ -175,7 +205,8 @@ public Builder setPassword(@Nullable String password) {
      * Creates an {@code HttpConnectProxiedSocketAddress}.
      */
     public HttpConnectProxiedSocketAddress build() {
-      return new HttpConnectProxiedSocketAddress(proxyAddress, targetAddress, username, password);
+      return new HttpConnectProxiedSocketAddress(
+          proxyAddress, targetAddress, headers, username, password);
     }
   }
 }
diff --git a/api/src/main/java/io/grpc/IgnoreJRERequirement.java b/api/src/main/java/io/grpc/IgnoreJRERequirement.java
new file mode 100644
index 00000000000..2db406c5953
--- /dev/null
+++ b/api/src/main/java/io/grpc/IgnoreJRERequirement.java
@@ -0,0 +1,30 @@
+/*
+ * Copyright 2024 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Target;
+
+/**
+ * Disables Animal Sniffer's signature checking. This is our own package-private version to avoid
+ * depending on animalsniffer-annotations.
+ *
+ * 

FIELD is purposefully not supported, as Android wouldn't be able to ignore a field. Instead, + * the entire class would need to be avoided on Android. + */ +@Target({ElementType.METHOD, ElementType.CONSTRUCTOR, ElementType.TYPE}) +@interface IgnoreJRERequirement {} diff --git a/api/src/main/java/io/grpc/InternalConfigSelector.java b/api/src/main/java/io/grpc/InternalConfigSelector.java index 38856f440b4..a63009361d4 100644 --- a/api/src/main/java/io/grpc/InternalConfigSelector.java +++ b/api/src/main/java/io/grpc/InternalConfigSelector.java @@ -35,7 +35,7 @@ public abstract class InternalConfigSelector { = Attributes.Key.create("internal:io.grpc.config-selector"); // Use PickSubchannelArgs for SelectConfigArgs for now. May change over time. - /** Selects the config for an PRC. */ + /** Selects the config for an RPC. */ public abstract Result selectConfig(LoadBalancer.PickSubchannelArgs args); public static final class Result { diff --git a/api/src/main/java/io/grpc/InternalConfiguratorRegistry.java b/api/src/main/java/io/grpc/InternalConfiguratorRegistry.java index b495800ff13..f567dab74c4 100644 --- a/api/src/main/java/io/grpc/InternalConfiguratorRegistry.java +++ b/api/src/main/java/io/grpc/InternalConfiguratorRegistry.java @@ -48,4 +48,8 @@ public static void configureServerBuilder(ServerBuilder serverBuilder) { public static boolean wasSetConfiguratorsCalled() { return ConfiguratorRegistry.getDefaultRegistry().wasSetConfiguratorsCalled(); } + + public static int getConfiguratorsCallCountBeforeSet() { + return ConfiguratorRegistry.getDefaultRegistry().getConfiguratorsCallCountBeforeSet(); + } } diff --git a/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java b/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java new file mode 100644 index 00000000000..d4bed4d81bc --- /dev/null +++ b/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java @@ -0,0 +1,29 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +@Internal +public final class InternalEquivalentAddressGroup { + private InternalEquivalentAddressGroup() {} + + /** + * Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32. + * Must not be zero. The weight is proportional to the other endpoints; if an endpoint's weight is + * twice that of another endpoint, it is intended to receive twice the load. + */ + public static final Attributes.Key ATTR_WEIGHT = EquivalentAddressGroup.ATTR_WEIGHT; +} diff --git a/api/src/main/java/io/grpc/InternalFeatureFlags.java b/api/src/main/java/io/grpc/InternalFeatureFlags.java new file mode 100644 index 00000000000..a1e771a7571 --- /dev/null +++ b/api/src/main/java/io/grpc/InternalFeatureFlags.java @@ -0,0 +1,41 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc; + +import com.google.common.annotations.VisibleForTesting; + +/** Global variables that govern major changes to the behavior of more than one grpc module. */ +@Internal +public class InternalFeatureFlags { + + /** Whether to parse targets as RFC 3986 URIs (true), or use {@link java.net.URI} (false). */ + @VisibleForTesting + public static boolean setRfc3986UrisEnabled(boolean value) { + return FeatureFlags.setRfc3986UrisEnabled(value); + } + + /** Whether to parse targets as RFC 3986 URIs (true), or use {@link java.net.URI} (false). */ + public static boolean getRfc3986UrisEnabled() { + return FeatureFlags.getRfc3986UrisEnabled(); + } + + public static boolean getFlag(String envVarName, boolean enableByDefault) { + return FeatureFlags.getFlag(envVarName, enableByDefault); + } + + private InternalFeatureFlags() {} +} diff --git a/api/src/main/java/io/grpc/InternalServiceProviders.java b/api/src/main/java/io/grpc/InternalServiceProviders.java index c72e01db67a..debc786a82a 100644 --- a/api/src/main/java/io/grpc/InternalServiceProviders.java +++ b/api/src/main/java/io/grpc/InternalServiceProviders.java @@ -17,7 +17,9 @@ package io.grpc; import com.google.common.annotations.VisibleForTesting; +import java.util.Iterator; import java.util.List; +import java.util.ServiceLoader; @Internal public final class InternalServiceProviders { @@ -27,12 +29,17 @@ private InternalServiceProviders() { /** * Accessor for method. 
*/ - public static T load( + @Deprecated + public static List loadAll( Class klass, - Iterable> hardcoded, + Iterable> hardCodedClasses, ClassLoader classLoader, PriorityAccessor priorityAccessor) { - return ServiceProviders.load(klass, hardcoded, classLoader, priorityAccessor); + return loadAll( + klass, + ServiceLoader.load(klass, classLoader).iterator(), + () -> hardCodedClasses, + priorityAccessor); } /** @@ -40,10 +47,10 @@ public static T load( */ public static List loadAll( Class klass, - Iterable> hardCodedClasses, - ClassLoader classLoader, + Iterator serviceLoader, + Supplier>> hardCodedClasses, PriorityAccessor priorityAccessor) { - return ServiceProviders.loadAll(klass, hardCodedClasses, classLoader, priorityAccessor); + return ServiceProviders.loadAll(klass, serviceLoader, hardCodedClasses::get, priorityAccessor); } /** @@ -71,4 +78,8 @@ public static boolean isAndroid(ClassLoader cl) { } public interface PriorityAccessor extends ServiceProviders.PriorityAccessor {} + + public interface Supplier { + T get(); + } } diff --git a/api/src/main/java/io/grpc/InternalStatus.java b/api/src/main/java/io/grpc/InternalStatus.java index b6549bb435f..56df1decf38 100644 --- a/api/src/main/java/io/grpc/InternalStatus.java +++ b/api/src/main/java/io/grpc/InternalStatus.java @@ -38,12 +38,11 @@ private InternalStatus() {} public static final Metadata.Key CODE_KEY = Status.CODE_KEY; /** - * Create a new {@link StatusRuntimeException} with the internal option of skipping the filling - * of the stack trace. + * Create a new {@link StatusRuntimeException} skipping the filling of the stack trace. 
*/ @Internal - public static final StatusRuntimeException asRuntimeException(Status status, - @Nullable Metadata trailers, boolean fillInStackTrace) { - return new StatusRuntimeException(status, trailers, fillInStackTrace); + public static StatusRuntimeException asRuntimeExceptionWithoutStacktrace(Status status, + @Nullable Metadata trailers) { + return new InternalStatusRuntimeException(status, trailers); } } diff --git a/api/src/main/java/io/grpc/InternalStatusRuntimeException.java b/api/src/main/java/io/grpc/InternalStatusRuntimeException.java new file mode 100644 index 00000000000..6090b701f0b --- /dev/null +++ b/api/src/main/java/io/grpc/InternalStatusRuntimeException.java @@ -0,0 +1,39 @@ +/* + * Copyright 2015 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import javax.annotation.Nullable; + +/** + * StatusRuntimeException without stack trace, implemented as a subclass, as the + * {@code String, Throwable, boolean, boolean} constructor is not available in the supported + * version of Android. 
+ * + * @see StatusRuntimeException + */ +class InternalStatusRuntimeException extends StatusRuntimeException { + private static final long serialVersionUID = 0; + + public InternalStatusRuntimeException(Status status, @Nullable Metadata trailers) { + super(status, trailers); + } + + @Override + public synchronized Throwable fillInStackTrace() { + return this; + } +} diff --git a/api/src/main/java/io/grpc/InternalTcpMetrics.java b/api/src/main/java/io/grpc/InternalTcpMetrics.java new file mode 100644 index 00000000000..3dd89b6f76c --- /dev/null +++ b/api/src/main/java/io/grpc/InternalTcpMetrics.java @@ -0,0 +1,98 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * TCP Metrics defined to be shared across transport implementations. + * These metrics and their definitions are specified in + * gRFC + * A80. 
+ */ +@Internal +public final class InternalTcpMetrics { + + private InternalTcpMetrics() { + } + + private static final List OPTIONAL_LABELS = Arrays.asList( + "network.local.address", + "network.local.port", + "network.peer.address", + "network.peer.port"); + + public static final DoubleHistogramMetricInstrument MIN_RTT_INSTRUMENT = + MetricInstrumentRegistry.getDefaultRegistry() + .registerDoubleHistogram( + "grpc.tcp.min_rtt", + "Minimum round-trip time of a TCP connection", + "s", + Collections.emptyList(), + Collections.emptyList(), + OPTIONAL_LABELS, + false); + + public static final LongCounterMetricInstrument CONNECTIONS_CREATED_INSTRUMENT = + MetricInstrumentRegistry + .getDefaultRegistry() + .registerLongCounter( + "grpc.tcp.connections_created", + "The total number of TCP connections established.", + "{connection}", + Collections.emptyList(), + OPTIONAL_LABELS, + false); + + public static final LongUpDownCounterMetricInstrument CONNECTION_COUNT_INSTRUMENT = + MetricInstrumentRegistry + .getDefaultRegistry() + .registerLongUpDownCounter( + "grpc.tcp.connection_count", + "The current number of active TCP connections.", + "{connection}", + Collections.emptyList(), + OPTIONAL_LABELS, + false); + + public static final LongCounterMetricInstrument PACKETS_RETRANSMITTED_INSTRUMENT = + MetricInstrumentRegistry + .getDefaultRegistry() + .registerLongCounter( + "grpc.tcp.packets_retransmitted", + "The total number of packets retransmitted for all TCP connections.", + "{packet}", + Collections.emptyList(), + OPTIONAL_LABELS, + false); + + public static final LongCounterMetricInstrument RECURRING_RETRANSMITS_INSTRUMENT = + MetricInstrumentRegistry + .getDefaultRegistry() + .registerLongCounter( + "grpc.tcp.recurring_retransmits", + "The total number of times the retransmit timer " + + "popped for all TCP connections.", + "{timeout}", + Collections.emptyList(), + OPTIONAL_LABELS, + false); + +} diff --git a/api/src/main/java/io/grpc/InternalTimeUtils.java 
b/api/src/main/java/io/grpc/InternalTimeUtils.java new file mode 100644 index 00000000000..ef8022f53c5 --- /dev/null +++ b/api/src/main/java/io/grpc/InternalTimeUtils.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.time.Duration; + +@Internal +public final class InternalTimeUtils { + public static long convert(Duration duration) { + return TimeUtils.convertToNanos(duration); + } +} diff --git a/api/src/main/java/io/grpc/LoadBalancer.java b/api/src/main/java/io/grpc/LoadBalancer.java index 0fbce5fa5be..3187ae8ef1b 100644 --- a/api/src/main/java/io/grpc/LoadBalancer.java +++ b/api/src/main/java/io/grpc/LoadBalancer.java @@ -121,6 +121,12 @@ public abstract class LoadBalancer { HEALTH_CONSUMER_LISTENER_ARG_KEY = LoadBalancer.CreateSubchannelArgs.Key.create("internal:health-check-consumer-listener"); + @Internal + public static final LoadBalancer.CreateSubchannelArgs.Key + DISABLE_SUBCHANNEL_RECONNECT_KEY = + LoadBalancer.CreateSubchannelArgs.Key.createWithDefault( + "internal:disable-subchannel-reconnect", Boolean.FALSE); + @Internal public static final Attributes.Key HAS_HEALTH_PRODUCER_LISTENER_KEY = @@ -150,15 +156,16 @@ public String toString() { private int recursionCount; /** - * Handles newly resolved server groups and metadata attributes from name resolution system. 
- * {@code servers} contained in {@link EquivalentAddressGroup} should be considered equivalent - * but may be flattened into a single list if needed. - * - *

Implementations should not modify the given {@code servers}. + * Handles newly resolved addresses and metadata attributes from name resolution system. + * Addresses in {@link EquivalentAddressGroup} should be considered equivalent but may be + * flattened into a single list if needed. * * @param resolvedAddresses the resolved server addresses, attributes, and config. * @since 1.21.0 + * + * @deprecated Use instead {@link #acceptResolvedAddresses(ResolvedAddresses)} */ + @Deprecated public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (recursionCount++ == 0) { // Note that the information about the addresses actually being accepted will be lost @@ -173,12 +180,10 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { * EquivalentAddressGroup} addresses should be considered equivalent but may be flattened into a * single list if needed. * - *

Implementations can choose to reject the given addresses by returning {@code false}. - * - *

Implementations should not modify the given {@code addresses}. + * @param resolvedAddresses the resolved server addresses, attributes, and config + * @return {@code Status.OK} if the resolved addresses were accepted, otherwise an error to report + * to the name resolver * - * @param resolvedAddresses the resolved server addresses, attributes, and config. - * @return {@code true} if the resolved addresses were accepted. {@code false} if rejected. * @since 1.49.0 */ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { @@ -206,7 +211,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { * * @since 1.21.0 */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771") + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11657") public static final class ResolvedAddresses { private final List addresses; @NameResolver.ResolutionResultAttr @@ -412,7 +417,16 @@ public void handleSubchannelState( * *

This method should always return a constant value. It's not specified when this will be * called. + * + *

Note that this method is only called when implementing {@code handleResolvedAddresses()} + * instead of {@code acceptResolvedAddresses()}. + * + * @deprecated Instead of overwriting this and {@code handleResolvedAddresses()}, only + * overwrite {@code acceptResolvedAddresses()} which indicates if the addresses provided + * by the name resolver are acceptable with the {@code boolean} return value. */ + @Deprecated + @SuppressWarnings("InlineMeSuggester") public boolean canHandleEmptyAddressListFromNameResolution() { return false; } @@ -446,18 +460,6 @@ public abstract static class SubchannelPicker { * @since 1.3.0 */ public abstract PickResult pickSubchannel(PickSubchannelArgs args); - - /** - * Tries to establish connections now so that the upcoming RPC may then just pick a ready - * connection without having to connect first. - * - *

No-op if unsupported. - * - * @deprecated override {@link LoadBalancer#requestConnection} instead. - * @since 1.11.0 - */ - @Deprecated - public void requestConnection() {} } /** @@ -546,6 +548,7 @@ public static final class PickResult { private final Status status; // True if the result is created by withDrop() private final boolean drop; + @Nullable private final String authorityOverride; private PickResult( @Nullable Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory, @@ -554,6 +557,17 @@ private PickResult( this.streamTracerFactory = streamTracerFactory; this.status = checkNotNull(status, "status"); this.drop = drop; + this.authorityOverride = null; + } + + private PickResult( + @Nullable Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory, + Status status, boolean drop, @Nullable String authorityOverride) { + this.subchannel = subchannel; + this.streamTracerFactory = streamTracerFactory; + this.status = checkNotNull(status, "status"); + this.drop = drop; + this.authorityOverride = authorityOverride; } /** @@ -626,6 +640,8 @@ private PickResult( * stream is created at all in some cases. * @since 1.3.0 */ + // TODO(shivaspeaks): Need to deprecate old APIs and create new ones, + // per https://github.com/grpc/grpc-java/issues/12662. public static PickResult withSubchannel( Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory) { return new PickResult( @@ -633,6 +649,19 @@ public static PickResult withSubchannel( false); } + /** + * Same as {@code withSubchannel(subchannel, streamTracerFactory)} but with an authority name + * to override in the host header. 
+ */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11656") + public static PickResult withSubchannel( + Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory, + @Nullable String authorityOverride) { + return new PickResult( + checkNotNull(subchannel, "subchannel"), streamTracerFactory, Status.OK, + false, authorityOverride); + } + /** * Equivalent to {@code withSubchannel(subchannel, null)}. * @@ -642,6 +671,28 @@ public static PickResult withSubchannel(Subchannel subchannel) { return withSubchannel(subchannel, null); } + /** + * Creates a new {@code PickResult} with the given {@code subchannel}, + * but retains all other properties from this {@code PickResult}. + * + * @since 1.80.0 + */ + public PickResult copyWithSubchannel(Subchannel subchannel) { + return new PickResult(checkNotNull(subchannel, "subchannel"), streamTracerFactory, + status, drop, authorityOverride); + } + + /** + * Creates a new {@code PickResult} with the given {@code streamTracerFactory}, + * but retains all other properties from this {@code PickResult}. + * + * @since 1.80.0 + */ + public PickResult copyWithStreamTracerFactory( + @Nullable ClientStreamTracer.Factory streamTracerFactory) { + return new PickResult(subchannel, streamTracerFactory, status, drop, authorityOverride); + } + /** * A decision to report a connectivity error to the RPC. If the RPC is {@link * CallOptions#withWaitForReady wait-for-ready}, it will stay buffered. Otherwise, it will fail @@ -676,6 +727,13 @@ public static PickResult withNoResult() { return NO_RESULT; } + /** Returns the authority override if any. */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11656") + @Nullable + public String getAuthorityOverride() { + return authorityOverride; + } + /** * The Subchannel if this result was created by {@link #withSubchannel withSubchannel()}, or * null otherwise. 
@@ -730,6 +788,7 @@ public String toString() { .add("streamTracerFactory", streamTracerFactory) .add("status", status) .add("drop", drop) + .add("authority-override", authorityOverride) .toString(); } @@ -828,9 +887,11 @@ public String toString() { @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771") public static final class Builder { + private static final Object[][] EMPTY_CUSTOM_OPTIONS = new Object[0][2]; + private List addrs; private Attributes attrs = Attributes.EMPTY; - private Object[][] customOptions = new Object[0][2]; + private Object[][] customOptions = EMPTY_CUSTOM_OPTIONS; Builder() { } @@ -994,8 +1055,8 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { } /** - * Out-of-band channel for LoadBalancer’s own RPC needs, e.g., talking to an external - * load-balancer service. + * Create an out-of-band channel for the LoadBalancer’s own RPC needs, e.g., talking to an + * external load-balancer service. * *

The LoadBalancer is responsible for closing unused OOB channels, and closing all OOB * channels within {@link #shutdown}. @@ -1005,7 +1066,12 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { public abstract ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority); /** - * Accept a list of EAG for multiple authorities: https://github.com/grpc/grpc-java/issues/4618 + * Create an out-of-band channel for the LoadBalancer's own RPC needs, e.g., talking to an + * external load-balancer service. This version of the method allows multiple EAGs, so different + * addresses can have different authorities. + * + *

The LoadBalancer is responsible for closing unused OOB channels, and closing all OOB + * channels within {@link #shutdown}. * */ public ManagedChannel createOobChannel(List eag, String authority) { @@ -1157,6 +1223,10 @@ public void ignoreRefreshNameResolutionCheck() { * Returns a {@link SynchronizationContext} that runs tasks in the same Synchronization Context * as that the callback methods on the {@link LoadBalancer} interface are run in. * + *

Work added to the synchronization context might not run immediately, so LB implementations + * must be careful to ensure that any assumptions still hold when it is executed. In particular, + * the LB might have been shut down or subchannels might have changed state. + * *

Pro-tip: in order to call {@link SynchronizationContext#schedule}, you need to provide a * {@link ScheduledExecutorService}. {@link #getScheduledExecutorService} is provided for your * convenience. diff --git a/api/src/main/java/io/grpc/LoadBalancerProvider.java b/api/src/main/java/io/grpc/LoadBalancerProvider.java index bb4c574211e..7dc30d6baaf 100644 --- a/api/src/main/java/io/grpc/LoadBalancerProvider.java +++ b/api/src/main/java/io/grpc/LoadBalancerProvider.java @@ -81,7 +81,7 @@ public abstract class LoadBalancerProvider extends LoadBalancer.Factory { * @return a tuple of the fully parsed and validated balancer configuration, else the Status. * @since 1.20.0 * @see - * A24-lb-policy-config.md + * A24-lb-policy-config.md */ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawLoadBalancingPolicyConfig) { return UNKNOWN_CONFIG; diff --git a/api/src/main/java/io/grpc/LoadBalancerRegistry.java b/api/src/main/java/io/grpc/LoadBalancerRegistry.java index f6b69f978b8..a8fbc102f5f 100644 --- a/api/src/main/java/io/grpc/LoadBalancerRegistry.java +++ b/api/src/main/java/io/grpc/LoadBalancerRegistry.java @@ -26,6 +26,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -42,7 +43,6 @@ public final class LoadBalancerRegistry { private static final Logger logger = Logger.getLogger(LoadBalancerRegistry.class.getName()); private static LoadBalancerRegistry instance; - private static final Iterable> HARDCODED_CLASSES = getHardCodedClasses(); private final LinkedHashSet allProviders = new LinkedHashSet<>(); @@ -101,8 +101,10 @@ public static synchronized LoadBalancerRegistry getDefaultRegistry() { if (instance == null) { List providerList = ServiceProviders.loadAll( LoadBalancerProvider.class, - HARDCODED_CLASSES, - LoadBalancerProvider.class.getClassLoader(), + ServiceLoader + .load(LoadBalancerProvider.class, 
LoadBalancerProvider.class.getClassLoader()) + .iterator(), + LoadBalancerRegistry::getHardCodedClasses, new LoadBalancerPriorityAccessor()); instance = new LoadBalancerRegistry(); for (LoadBalancerProvider provider : providerList) { diff --git a/api/src/main/java/io/grpc/LongUpDownCounterMetricInstrument.java b/api/src/main/java/io/grpc/LongUpDownCounterMetricInstrument.java new file mode 100644 index 00000000000..07e099cde5d --- /dev/null +++ b/api/src/main/java/io/grpc/LongUpDownCounterMetricInstrument.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.List; + +/** + * Represents a long-valued up down counter metric instrument. 
+ */ +@Internal +public final class LongUpDownCounterMetricInstrument extends PartialMetricInstrument { + public LongUpDownCounterMetricInstrument(int index, String name, String description, String unit, + List requiredLabelKeys, + List optionalLabelKeys, + boolean enableByDefault) { + super(index, name, description, unit, requiredLabelKeys, optionalLabelKeys, enableByDefault); + } +} \ No newline at end of file diff --git a/api/src/main/java/io/grpc/ManagedChannelBuilder.java b/api/src/main/java/io/grpc/ManagedChannelBuilder.java index 6e30d8eae04..3f370ab3003 100644 --- a/api/src/main/java/io/grpc/ManagedChannelBuilder.java +++ b/api/src/main/java/io/grpc/ManagedChannelBuilder.java @@ -374,9 +374,17 @@ public T maxInboundMetadataSize(int bytes) { * notice when they are causing excessive load. Clients are strongly encouraged to use only as * small of a value as necessary. * + *

When the channel implementation supports TCP_USER_TIMEOUT, enabling keepalive will also + * enable TCP_USER_TIMEOUT for the connection. This requires all sent packets to receive + * a TCP acknowledgement before the keepalive timeout. The keepalive time is not used for + * TCP_USER_TIMEOUT, except as a signal to enable the feature. grpc-netty supports + * TCP_USER_TIMEOUT on Linux platforms supported by netty-transport-native-epoll. + * * @throws UnsupportedOperationException if unsupported * @see gRFC A8 * Client-side Keepalive + * @see gRFC A18 + * TCP User Timeout * @since 1.7.0 */ public T keepAliveTime(long keepAliveTime, TimeUnit timeUnit) { @@ -393,6 +401,8 @@ public T keepAliveTime(long keepAliveTime, TimeUnit timeUnit) { * @throws UnsupportedOperationException if unsupported * @see gRFC A8 * Client-side Keepalive + * @see gRFC A18 + * TCP User Timeout * @since 1.7.0 */ public T keepAliveTimeout(long keepAliveTimeout, TimeUnit timeUnit) { @@ -633,6 +643,23 @@ protected T addMetricSink(MetricSink metricSink) { throw new UnsupportedOperationException(); } + /** + * Provides a "custom" argument for the {@link NameResolver}, if applicable, replacing any 'value' + * previously provided for 'key'. + * + *

NB: If the selected {@link NameResolver} does not understand 'key', or target URI resolution + * isn't needed at all, your custom argument will be silently ignored. + * + *

See {@link NameResolver.Args#getArg(NameResolver.Args.Key)} for more. + * + * @param key identifies the argument in a type-safe manner + * @param value the argument itself + * @return this + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1770") + public T setNameResolverArg(NameResolver.Args.Key key, X value) { + throw new UnsupportedOperationException(); + } /** * Builds a channel using the given parameters. diff --git a/api/src/main/java/io/grpc/ManagedChannelRegistry.java b/api/src/main/java/io/grpc/ManagedChannelRegistry.java index 31f874b8094..ec47b325ffc 100644 --- a/api/src/main/java/io/grpc/ManagedChannelRegistry.java +++ b/api/src/main/java/io/grpc/ManagedChannelRegistry.java @@ -18,6 +18,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.net.SocketAddress; import java.net.URI; import java.net.URISyntaxException; @@ -28,9 +29,9 @@ import java.util.Comparator; import java.util.LinkedHashSet; import java.util.List; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -100,8 +101,10 @@ public static synchronized ManagedChannelRegistry getDefaultRegistry() { if (instance == null) { List providerList = ServiceProviders.loadAll( ManagedChannelProvider.class, - getHardCodedClasses(), - ManagedChannelProvider.class.getClassLoader(), + ServiceLoader + .load(ManagedChannelProvider.class, ManagedChannelProvider.class.getClassLoader()) + .iterator(), + ManagedChannelRegistry::getHardCodedClasses, new ManagedChannelPriorityAccessor()); instance = new ManagedChannelRegistry(); for (ManagedChannelProvider provider : providerList) { @@ -160,8 +163,11 @@ ManagedChannelBuilder newChannelBuilder(NameResolverRegistry nameResolverRegi String target, ChannelCredentials creds) { 
NameResolverProvider nameResolverProvider = null; try { - URI uri = new URI(target); - nameResolverProvider = nameResolverRegistry.getProviderForScheme(uri.getScheme()); + String scheme = + FeatureFlags.getRfc3986UrisEnabled() + ? Uri.parse(target).getScheme() + : new URI(target).getScheme(); + nameResolverProvider = nameResolverRegistry.getProviderForScheme(scheme); } catch (URISyntaxException ignore) { // bad URI found, just ignore and continue } diff --git a/api/src/main/java/io/grpc/Metadata.java b/api/src/main/java/io/grpc/Metadata.java index fba2659776b..8a958d127df 100644 --- a/api/src/main/java/io/grpc/Metadata.java +++ b/api/src/main/java/io/grpc/Metadata.java @@ -22,6 +22,8 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; import com.google.common.io.BaseEncoding; import com.google.common.io.ByteStreams; import java.io.ByteArrayInputStream; @@ -32,8 +34,6 @@ import java.util.Arrays; import java.util.BitSet; import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -325,7 +325,7 @@ public Set keys() { if (isEmpty()) { return Collections.emptySet(); } - Set ks = new HashSet<>(size); + Set ks = Sets.newHashSetWithExpectedSize(size); for (int i = 0; i < size; i++) { ks.add(new String(name(i), 0 /* hibyte */)); } @@ -526,7 +526,7 @@ public void merge(Metadata other) { public void merge(Metadata other, Set> keys) { Preconditions.checkNotNull(other, "other"); // Use ByteBuffer for equals and hashCode. 
- Map> asciiKeys = new HashMap<>(keys.size()); + Map> asciiKeys = Maps.newHashMapWithExpectedSize(keys.size()); for (Key key : keys) { asciiKeys.put(ByteBuffer.wrap(key.asciiName()), key); } diff --git a/api/src/main/java/io/grpc/MethodDescriptor.java b/api/src/main/java/io/grpc/MethodDescriptor.java index 1bfaccb4201..a02eb840deb 100644 --- a/api/src/main/java/io/grpc/MethodDescriptor.java +++ b/api/src/main/java/io/grpc/MethodDescriptor.java @@ -20,9 +20,9 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; import java.io.InputStream; import java.util.concurrent.atomic.AtomicReferenceArray; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.annotation.concurrent.Immutable; diff --git a/api/src/main/java/io/grpc/MetricInstrumentRegistry.java b/api/src/main/java/io/grpc/MetricInstrumentRegistry.java index a61ac058a61..ce0f8f1b5cb 100644 --- a/api/src/main/java/io/grpc/MetricInstrumentRegistry.java +++ b/api/src/main/java/io/grpc/MetricInstrumentRegistry.java @@ -21,12 +21,12 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Strings; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; -import javax.annotation.concurrent.GuardedBy; /** * A registry for globally registered metric instruments. @@ -144,6 +144,47 @@ public LongCounterMetricInstrument registerLongCounter(String name, } } + /** + * Registers a new Long Up Down Counter metric instrument. 
+ * + * @param name the name of the metric + * @param description a description of the metric + * @param unit the unit of measurement for the metric + * @param requiredLabelKeys a list of required label keys + * @param optionalLabelKeys a list of optional label keys + * @param enableByDefault whether the metric should be enabled by default + * @return the newly created LongUpDownCounterMetricInstrument + * @throws IllegalStateException if a metric with the same name already exists + */ + public LongUpDownCounterMetricInstrument registerLongUpDownCounter(String name, + String description, + String unit, + List requiredLabelKeys, + List optionalLabelKeys, + boolean enableByDefault) { + checkArgument(!Strings.isNullOrEmpty(name), "missing metric name"); + checkNotNull(description, "description"); + checkNotNull(unit, "unit"); + checkNotNull(requiredLabelKeys, "requiredLabelKeys"); + checkNotNull(optionalLabelKeys, "optionalLabelKeys"); + synchronized (lock) { + if (registeredMetricNames.contains(name)) { + throw new IllegalStateException("Metric with name " + name + " already exists"); + } + int index = nextAvailableMetricIndex; + if (index + 1 == metricInstruments.length) { + resizeMetricInstruments(); + } + LongUpDownCounterMetricInstrument instrument = new LongUpDownCounterMetricInstrument( + index, name, description, unit, requiredLabelKeys, optionalLabelKeys, + enableByDefault); + metricInstruments[index] = instrument; + registeredMetricNames.add(name); + nextAvailableMetricIndex += 1; + return instrument; + } + } + /** * Registers a new Double Histogram metric instrument. 
* diff --git a/api/src/main/java/io/grpc/MetricRecorder.java b/api/src/main/java/io/grpc/MetricRecorder.java index d418dcbf590..897c28011cd 100644 --- a/api/src/main/java/io/grpc/MetricRecorder.java +++ b/api/src/main/java/io/grpc/MetricRecorder.java @@ -50,7 +50,7 @@ default void addDoubleCounter(DoubleCounterMetricInstrument metricInstrument, do * Adds a value for a long valued counter metric instrument. * * @param metricInstrument The counter metric instrument to add the value against. - * @param value The value to add. + * @param value The value to add. MUST be non-negative. * @param requiredLabelValues A list of required label values for the metric. * @param optionalLabelValues A list of additional, optional label values for the metric. */ @@ -66,6 +66,29 @@ default void addLongCounter(LongCounterMetricInstrument metricInstrument, long v metricInstrument.getOptionalLabelKeys().size()); } + /** + * Adds a value for a long valued up down counter metric instrument. + * + * @param metricInstrument The counter metric instrument to add the value against. + * @param value The value to add. May be positive, negative or zero. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void addLongUpDownCounter(LongUpDownCounterMetricInstrument metricInstrument, + long value, + List requiredLabelValues, + List optionalLabelValues) { + checkArgument(requiredLabelValues != null + && requiredLabelValues.size() == metricInstrument.getRequiredLabelKeys().size(), + "Incorrect number of required labels provided. Expected: %s", + metricInstrument.getRequiredLabelKeys().size()); + checkArgument(optionalLabelValues != null + && optionalLabelValues.size() == metricInstrument.getOptionalLabelKeys().size(), + "Incorrect number of optional labels provided. 
Expected: %s", + metricInstrument.getOptionalLabelKeys().size()); + } + + /** * Records a value for a double-precision histogram metric instrument. * diff --git a/api/src/main/java/io/grpc/MetricSink.java b/api/src/main/java/io/grpc/MetricSink.java index 0f56b1acb73..ce5d3822520 100644 --- a/api/src/main/java/io/grpc/MetricSink.java +++ b/api/src/main/java/io/grpc/MetricSink.java @@ -65,12 +65,26 @@ default void addDoubleCounter(DoubleCounterMetricInstrument metricInstrument, do * Adds a value for a long valued counter metric associated with specified metric instrument. * * @param metricInstrument The counter metric instrument identifies metric measure to add. - * @param value The value to record. + * @param value The value to record. MUST be non-negative. * @param requiredLabelValues A list of required label values for the metric. * @param optionalLabelValues A list of additional, optional label values for the metric. */ default void addLongCounter(LongCounterMetricInstrument metricInstrument, long value, - List requiredLabelValues, List optionalLabelValues) { + List requiredLabelValues, List optionalLabelValues) { + } + + /** + * Adds a value for a long valued up down counter metric associated with specified metric + * instrument. + * + * @param metricInstrument The counter metric instrument identifies metric measure to add. + * @param value The value to record. May be positive, negative or zero. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. 
+ */ + default void addLongUpDownCounter(LongUpDownCounterMetricInstrument metricInstrument, long value, + List requiredLabelValues, + List optionalLabelValues) { } /** diff --git a/api/src/main/java/io/grpc/NameResolver.java b/api/src/main/java/io/grpc/NameResolver.java index bfb9c2a43a1..e44a26309ae 100644 --- a/api/src/main/java/io/grpc/NameResolver.java +++ b/api/src/main/java/io/grpc/NameResolver.java @@ -20,19 +20,21 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; +import com.google.common.base.MoreObjects.ToStringHelper; import com.google.common.base.Objects; import com.google.errorprone.annotations.InlineMe; import java.lang.annotation.Documented; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.net.URI; -import java.util.ArrayList; import java.util.Collections; +import java.util.IdentityHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import javax.annotation.Nullable; +import javax.annotation.concurrent.Immutable; import javax.annotation.concurrent.ThreadSafe; /** @@ -95,7 +97,14 @@ public void onError(Status error) { @Override public void onResult(ResolutionResult resolutionResult) { - listener.onAddresses(resolutionResult.getAddresses(), resolutionResult.getAttributes()); + StatusOr> addressesOrError = + resolutionResult.getAddressesOrError(); + if (addressesOrError.hasValue()) { + listener.onAddresses(addressesOrError.getValue(), + resolutionResult.getAttributes()); + } else { + listener.onError(addressesOrError.getStatus()); + } } }); } @@ -149,6 +158,10 @@ public abstract static class Factory { * cannot be resolved by this factory. The decision should be solely based on the scheme of the * URI. * + *

This method will eventually be deprecated and removed as part of a migration from {@code + * java.net.URI} to {@code io.grpc.Uri}. Implementations will override {@link + * #newNameResolver(Uri, Args)} instead. + * * @param targetUri the target URI to be resolved, whose scheme must not be {@code null} * @param args other information that may be useful * @@ -156,6 +169,37 @@ public abstract static class Factory { */ public abstract NameResolver newNameResolver(URI targetUri, final Args args); + /** + * Creates a {@link NameResolver} for the given target URI. + * + *

Implementations return {@code null} if 'targetUri' cannot be resolved by this factory. The + * decision should be solely based on the target's scheme. + * + *

All {@link NameResolver.Factory} implementations should override this method, as it will + * eventually replace {@link #newNameResolver(URI, Args)}. For backwards compatibility, this + * default implementation delegates to {@link #newNameResolver(URI, Args)} if 'targetUri' can be + * converted to a java.net.URI. + * + *

NB: Conversion is not always possible, for example {@code scheme:#frag} is a valid {@link + * Uri} but not a valid {@link URI} because its path is empty. The default implementation throws + * IllegalArgumentException in these cases. + * + * @param targetUri the target URI to be resolved + * @param args other information that may be useful + * @throws IllegalArgumentException if targetUri does not have the expected form + * @since 1.79 + */ + public NameResolver newNameResolver(Uri targetUri, final Args args) { + // Not every io.grpc.Uri can be converted but in the ordinary ManagedChannel creation flow, + // any IllegalArgumentException thrown here would happened anyway, just earlier. That's + // because parse/toString is transparent so java.net.URI#create here sees the original target + // string just like it did before the io.grpc.Uri migration. + // + // Throwing IAE shouldn't surprise non-framework callers either. After all, many existing + // Factory impls are picky about targetUri and throw IAE when it doesn't look how they expect. 
+ return newNameResolver(URI.create(targetUri.toString()), args); + } + /** * Returns the default scheme, which will be used to construct a URI when {@link * ManagedChannelBuilder#forTarget(String)} is given an authority string instead of a compliant @@ -218,19 +262,26 @@ public abstract static class Listener2 implements Listener { @Override @Deprecated @InlineMe( - replacement = "this.onResult(ResolutionResult.newBuilder().setAddresses(servers)" - + ".setAttributes(attributes).build())", - imports = "io.grpc.NameResolver.ResolutionResult") + replacement = "this.onResult(ResolutionResult.newBuilder().setAddressesOrError(" + + "StatusOr.fromValue(servers)).setAttributes(attributes).build())", + imports = {"io.grpc.NameResolver.ResolutionResult", "io.grpc.StatusOr"}) public final void onAddresses( List servers, @ResolutionResultAttr Attributes attributes) { // TODO(jihuncho) need to promote Listener2 if we want to use ConfigOrError + // Calling onResult and not onResult2 because onResult2 can only be called from a + // synchronization context. onResult( - ResolutionResult.newBuilder().setAddresses(servers).setAttributes(attributes).build()); + ResolutionResult.newBuilder().setAddressesOrError( + StatusOr.fromValue(servers)).setAttributes(attributes).build()); } /** * Handles updates on resolved addresses and attributes. If - * {@link ResolutionResult#getAddresses()} is empty, {@link #onError(Status)} will be called. + * {@link ResolutionResult#getAddressesOrError()} is empty, {@link #onError(Status)} will be + * called. + * + *

Newer NameResolver implementations should prefer calling onResult2. This method exists to + * facilitate older {@link Listener} implementations to migrate to {@link Listener2}. * * @param resolutionResult the resolved server addresses, attributes, and Service Config. * @since 1.21.0 @@ -241,6 +292,10 @@ public final void onAddresses( * Handles a name resolving error from the resolver. The listener is responsible for eventually * invoking {@link NameResolver#refresh()} to re-attempt resolution. * + *

New NameResolver implementations should prefer calling onResult2 which will have the + * address resolution error in {@link ResolutionResult}'s addressesOrError. This method exists + * to facilitate older implementations using {@link Listener} to migrate to {@link Listener2}. + * * @param error a non-OK status * @since 1.21.0 */ @@ -248,9 +303,14 @@ public final void onAddresses( public abstract void onError(Status error); /** - * Handles updates on resolved addresses and attributes. + * Handles updates on resolved addresses and attributes. Must be called from the same + * {@link SynchronizationContext} available in {@link NameResolver.Args} that is passed + * from the channel. * - * @param resolutionResult the resolved server addresses, attributes, and Service Config. + * @param resolutionResult the resolved server addresses or error in address resolution, + * attributes, and Service Config or error + * @return status indicating whether the resolutionResult was accepted by the listener, + * typically the result from a load balancer. * @since 1.66 */ public Status onResult2(ResolutionResult resolutionResult) { @@ -268,10 +328,20 @@ public Status onResult2(ResolutionResult resolutionResult) { @Documented public @interface ResolutionResultAttr {} + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11989") + @ResolutionResultAttr + public static final Attributes.Key ATTR_BACKEND_SERVICE = + Attributes.Key.create("io.grpc.NameResolver.ATTR_BACKEND_SERVICE"); + /** * Information that a {@link Factory} uses to create a {@link NameResolver}. * - *

Note this class doesn't override neither {@code equals()} nor {@code hashCode()}. + *

Args applicable to all {@link NameResolver}s are defined here using ordinary setters and + * getters. This container can also hold externally-defined "custom" args that aren't so widely + * useful or that would be inappropriate dependencies for this low level API. See {@link + * Args#getArg} for more. + * + *

Note this class overrides neither {@code equals()} nor {@code hashCode()}. * * @since 1.21.0 */ @@ -285,24 +355,24 @@ public static final class Args { @Nullable private final ChannelLogger channelLogger; @Nullable private final Executor executor; @Nullable private final String overrideAuthority; - - private Args( - Integer defaultPort, - ProxyDetector proxyDetector, - SynchronizationContext syncContext, - ServiceConfigParser serviceConfigParser, - @Nullable ScheduledExecutorService scheduledExecutorService, - @Nullable ChannelLogger channelLogger, - @Nullable Executor executor, - @Nullable String overrideAuthority) { - this.defaultPort = checkNotNull(defaultPort, "defaultPort not set"); - this.proxyDetector = checkNotNull(proxyDetector, "proxyDetector not set"); - this.syncContext = checkNotNull(syncContext, "syncContext not set"); - this.serviceConfigParser = checkNotNull(serviceConfigParser, "serviceConfigParser not set"); - this.scheduledExecutorService = scheduledExecutorService; - this.channelLogger = channelLogger; - this.executor = executor; - this.overrideAuthority = overrideAuthority; + private final MetricRecorder metricRecorder; + @Nullable private final NameResolverRegistry nameResolverRegistry; + @Nullable private final IdentityHashMap, Object> customArgs; + + private Args(Builder builder) { + this.defaultPort = checkNotNull(builder.defaultPort, "defaultPort not set"); + this.proxyDetector = checkNotNull(builder.proxyDetector, "proxyDetector not set"); + this.syncContext = checkNotNull(builder.syncContext, "syncContext not set"); + this.serviceConfigParser = + checkNotNull(builder.serviceConfigParser, "serviceConfigParser not set"); + this.scheduledExecutorService = builder.scheduledExecutorService; + this.channelLogger = builder.channelLogger; + this.executor = builder.executor; + this.overrideAuthority = builder.overrideAuthority; + this.metricRecorder = builder.metricRecorder != null ? 
builder.metricRecorder + : new MetricRecorder() {}; + this.nameResolverRegistry = builder.nameResolverRegistry; + this.customArgs = cloneCustomArgs(builder.customArgs); } /** @@ -311,6 +381,7 @@ private Args( * * @since 1.21.0 */ + //

TODO: Only meaningful for InetSocketAddress producers. Make this a custom arg? public int getDefaultPort() { return defaultPort; } @@ -363,6 +434,30 @@ public ServiceConfigParser getServiceConfigParser() { return serviceConfigParser; } + /** + * Returns the value of a custom arg named 'key', or {@code null} if it's not set. + * + *

While ordinary {@link Args} should be universally useful and meaningful, custom arguments + * can apply just to resolvers of a certain URI scheme, just to resolvers producing a particular + * type of {@link java.net.SocketAddress}, or even an individual {@link NameResolver} subclass. + * Custom args are identified by an instance of {@link Args.Key} which should be a constant + * defined in a java package and class appropriate for the argument's scope. + * + *

{@link Args} are normally reserved for information in *support* of name resolution, not + * the name to be resolved itself. However, there are rare cases where all or part of the target + * name can't be represented by any standard URI scheme or can't be encoded as a String at all. + * Custom args, in contrast, can hold arbitrary Java types, making them a useful work around in + * these cases. + * + *

Custom args can also be used simply to avoid adding inappropriate deps to the low level + * io.grpc package. + */ + @SuppressWarnings("unchecked") // Cast is safe because all put()s go through the setArg() API. + @Nullable + public T getArg(Key key) { + return customArgs != null ? (T) customArgs.get(key) : null; + } + /** * Returns the {@link ChannelLogger} for the Channel served by this NameResolver. * @@ -400,6 +495,25 @@ public String getOverrideAuthority() { return overrideAuthority; } + /** + * Returns the {@link MetricRecorder} that the channel uses to record metrics. + */ + public MetricRecorder getMetricRecorder() { + return metricRecorder; + } + + /** + * Returns the {@link NameResolverRegistry} that the Channel uses to look for {@link + * NameResolver}s. + * + * @since 1.74.0 + */ + public NameResolverRegistry getNameResolverRegistry() { + if (nameResolverRegistry == null) { + throw new IllegalStateException("NameResolverRegistry is not set in Builder"); + } + return nameResolverRegistry; + } @Override public String toString() { @@ -408,10 +522,13 @@ public String toString() { .add("proxyDetector", proxyDetector) .add("syncContext", syncContext) .add("serviceConfigParser", serviceConfigParser) + .add("customArgs", customArgs) .add("scheduledExecutorService", scheduledExecutorService) .add("channelLogger", channelLogger) .add("executor", executor) .add("overrideAuthority", overrideAuthority) + .add("metricRecorder", metricRecorder) + .add("nameResolverRegistry", nameResolverRegistry) .toString(); } @@ -430,6 +547,9 @@ public Builder toBuilder() { builder.setChannelLogger(channelLogger); builder.setOffloadExecutor(executor); builder.setOverrideAuthority(overrideAuthority); + builder.setMetricRecorder(metricRecorder); + builder.setNameResolverRegistry(nameResolverRegistry); + builder.customArgs = cloneCustomArgs(customArgs); return builder; } @@ -456,6 +576,9 @@ public static final class Builder { private ChannelLogger channelLogger; private Executor 
executor; private String overrideAuthority; + private MetricRecorder metricRecorder; + private NameResolverRegistry nameResolverRegistry; + private IdentityHashMap, Object> customArgs; Builder() { } @@ -542,16 +665,75 @@ public Builder setOverrideAuthority(String authority) { return this; } + /** See {@link Args#getArg(Key)}. */ + public Builder setArg(Key key, T value) { + checkNotNull(key, "key"); + checkNotNull(value, "value"); + if (customArgs == null) { + customArgs = new IdentityHashMap<>(); + } + customArgs.put(key, value); + return this; + } + + /** + * See {@link Args#getMetricRecorder()}. This is an optional field. + */ + public Builder setMetricRecorder(MetricRecorder metricRecorder) { + this.metricRecorder = checkNotNull(metricRecorder, "metricRecorder"); + return this; + } + + /** + * See {@link Args#getNameResolverRegistry}. This is an optional field. + * + * @since 1.74.0 + */ + public Builder setNameResolverRegistry(NameResolverRegistry registry) { + this.nameResolverRegistry = registry; + return this; + } + /** * Builds an {@link Args}. * * @since 1.21.0 */ public Args build() { - return - new Args( - defaultPort, proxyDetector, syncContext, serviceConfigParser, - scheduledExecutorService, channelLogger, executor, overrideAuthority); + return new Args(this); + } + } + + /** + * Identifies an externally-defined custom argument that can be stored in {@link Args}. + * + *

Uses reference equality so keys should be defined as global constants. + * + * @param type of values that can be stored under this key + */ + @Immutable + @SuppressWarnings("UnusedTypeParameter") + public static final class Key { + private final String debugString; + + private Key(String debugString) { + this.debugString = debugString; + } + + @Override + public String toString() { + return debugString; + } + + /** + * Creates a new instance of {@link Key}. + * + * @param debugString a string used to describe the key, used for debugging. + * @param Key type + * @return a new instance of Key + */ + public static Key create(String debugString) { + return new Key<>(debugString); } } } @@ -584,17 +766,17 @@ public abstract static class ServiceConfigParser { */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1770") public static final class ResolutionResult { - private final List addresses; + private final StatusOr> addressesOrError; @ResolutionResultAttr private final Attributes attributes; @Nullable private final ConfigOrError serviceConfig; ResolutionResult( - List addresses, + StatusOr> addressesOrError, @ResolutionResultAttr Attributes attributes, ConfigOrError serviceConfig) { - this.addresses = Collections.unmodifiableList(new ArrayList<>(addresses)); + this.addressesOrError = addressesOrError; this.attributes = checkNotNull(attributes, "attributes"); this.serviceConfig = serviceConfig; } @@ -615,7 +797,7 @@ public static Builder newBuilder() { */ public Builder toBuilder() { return newBuilder() - .setAddresses(addresses) + .setAddressesOrError(addressesOrError) .setAttributes(attributes) .setServiceConfig(serviceConfig); } @@ -624,9 +806,20 @@ public Builder toBuilder() { * Gets the addresses resolved by name resolution. 
* * @since 1.21.0 + * @deprecated Will be superseded by getAddressesOrError */ + @Deprecated public List getAddresses() { - return addresses; + return addressesOrError.getValue(); + } + + /** + * Gets the addresses resolved by name resolution or the error in doing so. + * + * @since 1.65.0 + */ + public StatusOr> getAddressesOrError() { + return addressesOrError; } /** @@ -652,11 +845,11 @@ public ConfigOrError getServiceConfig() { @Override public String toString() { - return MoreObjects.toStringHelper(this) - .add("addresses", addresses) - .add("attributes", attributes) - .add("serviceConfig", serviceConfig) - .toString(); + ToStringHelper stringHelper = MoreObjects.toStringHelper(this); + stringHelper.add("addressesOrError", addressesOrError.toString()); + stringHelper.add("attributes", attributes); + stringHelper.add("serviceConfigOrError", serviceConfig); + return stringHelper.toString(); } /** @@ -668,7 +861,7 @@ public boolean equals(Object obj) { return false; } ResolutionResult that = (ResolutionResult) obj; - return Objects.equal(this.addresses, that.addresses) + return Objects.equal(this.addressesOrError, that.addressesOrError) && Objects.equal(this.attributes, that.attributes) && Objects.equal(this.serviceConfig, that.serviceConfig); } @@ -678,7 +871,7 @@ public boolean equals(Object obj) { */ @Override public int hashCode() { - return Objects.hashCode(addresses, attributes, serviceConfig); + return Objects.hashCode(addressesOrError, attributes, serviceConfig); } /** @@ -688,7 +881,8 @@ public int hashCode() { */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1770") public static final class Builder { - private List addresses = Collections.emptyList(); + private StatusOr> addresses = + StatusOr.fromValue(Collections.emptyList()); private Attributes attributes = Attributes.EMPTY; @Nullable private ConfigOrError serviceConfig; @@ -700,9 +894,21 @@ public static final class Builder { * Sets the addresses resolved by name resolution. 
This field is required. * * @since 1.21.0 + * @deprecated Will be superseded by setAddressesOrError */ + @Deprecated public Builder setAddresses(List addresses) { - this.addresses = addresses; + setAddressesOrError(StatusOr.fromValue(addresses)); + return this; + } + + /** + * Sets the addresses resolved by name resolution or the error in doing so. This field is + * required. + * @param addresses Resolved addresses or an error in resolving addresses + */ + public Builder setAddressesOrError(StatusOr> addresses) { + this.addresses = checkNotNull(addresses, "StatusOr addresses cannot be null."); return this; } @@ -825,4 +1031,10 @@ public String toString() { } } } + + @Nullable + private static IdentityHashMap, Object> cloneCustomArgs( + @Nullable IdentityHashMap, Object> customArgs) { + return customArgs != null ? new IdentityHashMap<>(customArgs) : null; + } } diff --git a/api/src/main/java/io/grpc/NameResolverRegistry.java b/api/src/main/java/io/grpc/NameResolverRegistry.java index 23eec23fd6a..c5e9f7467ab 100644 --- a/api/src/main/java/io/grpc/NameResolverRegistry.java +++ b/api/src/main/java/io/grpc/NameResolverRegistry.java @@ -20,6 +20,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.net.URI; import java.util.ArrayList; import java.util.Collections; @@ -28,10 +29,10 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -125,8 +126,10 @@ public static synchronized NameResolverRegistry getDefaultRegistry() { if (instance == null) { List providerList = ServiceProviders.loadAll( NameResolverProvider.class, - getHardCodedClasses(), - NameResolverProvider.class.getClassLoader(), + 
ServiceLoader + .load(NameResolverProvider.class, NameResolverProvider.class.getClassLoader()) + .iterator(), + NameResolverRegistry::getHardCodedClasses, new NameResolverPriorityAccessor()); if (providerList.isEmpty()) { logger.warning("No NameResolverProviders found via ServiceLoader, including for DNS. This " @@ -166,6 +169,11 @@ static List> getHardCodedClasses() { } catch (ClassNotFoundException e) { logger.log(Level.FINE, "Unable to find DNS NameResolver", e); } + try { + list.add(Class.forName("io.grpc.binder.internal.IntentNameResolverProvider")); + } catch (ClassNotFoundException e) { + logger.log(Level.FINE, "Unable to find IntentNameResolverProvider", e); + } return Collections.unmodifiableList(list); } @@ -177,6 +185,13 @@ public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { return provider == null ? null : provider.newNameResolver(targetUri, args); } + @Override + @Nullable + public NameResolver newNameResolver(io.grpc.Uri targetUri, NameResolver.Args args) { + NameResolverProvider provider = getProviderForScheme(targetUri.getScheme()); + return provider == null ? null : provider.newNameResolver(targetUri, args); + } + @Override public String getDefaultScheme() { return NameResolverRegistry.this.getDefaultScheme(); diff --git a/api/src/main/java/io/grpc/ServerBuilder.java b/api/src/main/java/io/grpc/ServerBuilder.java index cd1cddbb93f..3effe593e57 100644 --- a/api/src/main/java/io/grpc/ServerBuilder.java +++ b/api/src/main/java/io/grpc/ServerBuilder.java @@ -435,6 +435,17 @@ public T setBinaryLog(BinaryLog binaryLog) { */ public abstract Server build(); + /** + * Adds a metric sink to the server. + * + * @param metricSink the metric sink to add. + * @return this + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12693") + public T addMetricSink(MetricSink metricSink) { + return thisT(); + } + /** * Returns the correctly typed version of the builder. 
*/ diff --git a/api/src/main/java/io/grpc/ServerRegistry.java b/api/src/main/java/io/grpc/ServerRegistry.java index a083e45a000..1ec7030b82b 100644 --- a/api/src/main/java/io/grpc/ServerRegistry.java +++ b/api/src/main/java/io/grpc/ServerRegistry.java @@ -18,14 +18,15 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.LinkedHashSet; import java.util.List; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -93,8 +94,9 @@ public static synchronized ServerRegistry getDefaultRegistry() { if (instance == null) { List providerList = ServiceProviders.loadAll( ServerProvider.class, - getHardCodedClasses(), - ServerProvider.class.getClassLoader(), + ServiceLoader.load(ServerProvider.class, ServerProvider.class.getClassLoader()) + .iterator(), + ServerRegistry::getHardCodedClasses, new ServerPriorityAccessor()); instance = new ServerRegistry(); for (ServerProvider provider : providerList) { diff --git a/api/src/main/java/io/grpc/ServiceProviders.java b/api/src/main/java/io/grpc/ServiceProviders.java index ac4b27d8783..861688be9fb 100644 --- a/api/src/main/java/io/grpc/ServiceProviders.java +++ b/api/src/main/java/io/grpc/ServiceProviders.java @@ -17,10 +17,13 @@ package io.grpc; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Supplier; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; +import java.util.Iterator; import java.util.List; +import java.util.ListIterator; import java.util.ServiceConfigurationError; import java.util.ServiceLoader; @@ -29,42 +32,44 @@ private ServiceProviders() { // do not instantiate } - /** - * If this is not 
Android, returns the highest priority implementation of the class via - * {@link ServiceLoader}. - * If this is Android, returns an instance of the highest priority class in {@code hardcoded}. - */ - public static T load( - Class klass, - Iterable> hardcoded, - ClassLoader cl, - PriorityAccessor priorityAccessor) { - List candidates = loadAll(klass, hardcoded, cl, priorityAccessor); - if (candidates.isEmpty()) { - return null; - } - return candidates.get(0); - } - /** * If this is not Android, returns all available implementations discovered via * {@link ServiceLoader}. * If this is Android, returns all available implementations in {@code hardcoded}. * The list is sorted in descending priority order. + * + *

{@code serviceLoader} should be created with {@code ServiceLoader.load(MyClass.class, + * MyClass.class.getClassLoader()).iterator()} in order to be detected by R8 so that R8 full mode + * will keep the constructors for the provider classes. */ public static List loadAll( Class klass, - Iterable> hardcoded, - ClassLoader cl, + Iterator serviceLoader, + Supplier>> hardcoded, final PriorityAccessor priorityAccessor) { - Iterable candidates; - if (isAndroid(cl)) { - candidates = getCandidatesViaHardCoded(klass, hardcoded); + Iterator candidates; + if (serviceLoader instanceof ListIterator) { + // A rewriting tool has replaced the ServiceLoader with a List of some sort (R8 uses + // ArrayList, AppReduce uses singletonList). We prefer to use such iterators on Android as + // they won't need reflection like the hard-coded list does. In addition, the provider + // instances will have already been created, so it seems we should use them. + // + // R8: https://r8.googlesource.com/r8/+/490bc53d9310d4cc2a5084c05df4aadaec8c885d/src/main/java/com/android/tools/r8/ir/optimize/ServiceLoaderRewriter.java + // AppReduce: service_loader_pass.cc + candidates = serviceLoader; + } else if (isAndroid(klass.getClassLoader())) { + // Avoid getResource() on Android, which must read from a zip which uses a lot of memory + candidates = getCandidatesViaHardCoded(klass, hardcoded.get()).iterator(); + } else if (!serviceLoader.hasNext()) { + // Attempt to load using the context class loader and ServiceLoader. + // This allows frameworks like http://aries.apache.org/modules/spi-fly.html to plug in. 
+ candidates = ServiceLoader.load(klass).iterator(); } else { - candidates = getCandidatesViaServiceLoader(klass, cl); + candidates = serviceLoader; } List list = new ArrayList<>(); - for (T current: candidates) { + while (candidates.hasNext()) { + T current = candidates.next(); if (!priorityAccessor.isAvailable(current)) { continue; } @@ -101,15 +106,14 @@ static boolean isAndroid(ClassLoader cl) { } /** - * Loads service providers for the {@code klass} service using {@link ServiceLoader}. + * For testing only: Loads service providers for the {@code klass} service using {@link + * ServiceLoader}. Does not support spi-fly and related tricks. */ @VisibleForTesting public static Iterable getCandidatesViaServiceLoader(Class klass, ClassLoader cl) { Iterable i = ServiceLoader.load(klass, cl); - // Attempt to load using the context class loader and ServiceLoader. - // This allows frameworks like http://aries.apache.org/modules/spi-fly.html to plug in. if (!i.iterator().hasNext()) { - i = ServiceLoader.load(klass); + return null; } return i; } diff --git a/api/src/main/java/io/grpc/Status.java b/api/src/main/java/io/grpc/Status.java index 5d7dd30df01..38cd9581f8e 100644 --- a/api/src/main/java/io/grpc/Status.java +++ b/api/src/main/java/io/grpc/Status.java @@ -23,6 +23,7 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Objects; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Metadata.TrustedAsciiMarshaller; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -30,7 +31,6 @@ import java.util.Collections; import java.util.List; import java.util.TreeMap; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.annotation.concurrent.Immutable; diff --git a/api/src/main/java/io/grpc/StatusException.java b/api/src/main/java/io/grpc/StatusException.java index b719f881132..c0a67a375b2 100644 --- a/api/src/main/java/io/grpc/StatusException.java +++ 
b/api/src/main/java/io/grpc/StatusException.java @@ -25,7 +25,9 @@ */ public class StatusException extends Exception { private static final long serialVersionUID = -660954903976144640L; + @SuppressWarnings("serial") // https://github.com/grpc/grpc-java/issues/1913 private final Status status; + @SuppressWarnings("serial") private final Metadata trailers; /** @@ -44,12 +46,7 @@ public StatusException(Status status) { * @since 1.0.0 */ public StatusException(Status status, @Nullable Metadata trailers) { - this(status, trailers, /*fillInStackTrace=*/ true); - } - - StatusException(Status status, @Nullable Metadata trailers, boolean fillInStackTrace) { - super(Status.formatThrowableMessage(status), status.getCause(), - /* enableSuppression */ true, /* writableStackTrace */fillInStackTrace); + super(Status.formatThrowableMessage(status), status.getCause()); this.status = status; this.trailers = trailers; } @@ -68,6 +65,7 @@ public final Status getStatus() { * * @since 1.0.0 */ + @Nullable public final Metadata getTrailers() { return trailers; } diff --git a/api/src/main/java/io/grpc/StatusOr.java b/api/src/main/java/io/grpc/StatusOr.java new file mode 100644 index 00000000000..b7dd68cfd7b --- /dev/null +++ b/api/src/main/java/io/grpc/StatusOr.java @@ -0,0 +1,111 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.base.MoreObjects; +import com.google.common.base.MoreObjects.ToStringHelper; +import com.google.common.base.Objects; +import javax.annotation.Nullable; + +/** Either a Status or a value. */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11563") +public class StatusOr { + private StatusOr(Status status, T value) { + this.status = status; + this.value = value; + } + + /** Construct from a value. */ + public static StatusOr fromValue(T value) { + StatusOr result = new StatusOr(null, value); + return result; + } + + /** Construct from a non-Ok status. */ + public static StatusOr fromStatus(Status status) { + StatusOr result = new StatusOr(checkNotNull(status, "status"), null); + checkArgument(!status.isOk(), "cannot use OK status: %s", status); + return result; + } + + /** Returns whether there is a value. */ + public boolean hasValue() { + return status == null; + } + + /** + * Returns the value if set or throws exception if there is no value set. This method is meant + * to be called after checking the return value of hasValue() first. + */ + public T getValue() { + if (status != null) { + throw new IllegalStateException("No value present."); + } + return value; + } + + /** Returns the status. If there is a value (which can be null), returns OK. */ + public Status getStatus() { + return status == null ? Status.OK : status; + } + + /** + * Note that StatusOr containing statuses, the equality comparision is delegated to + * {@link Status#equals} which just does a reference equality check because equality on + * Statuses is not well defined. + * Instead, do comparison based on their Code with {@link Status#getCode}. The description and + * cause of the Status are unlikely to be stable, and additional fields may be added to Status + * in the future. 
+ */ + @Override + public boolean equals(Object other) { + if (!(other instanceof StatusOr)) { + return false; + } + StatusOr otherStatus = (StatusOr) other; + if (hasValue() != otherStatus.hasValue()) { + return false; + } + if (hasValue()) { + return Objects.equal(value, otherStatus.value); + } + return Objects.equal(status, otherStatus.status); + } + + @Override + public int hashCode() { + return Objects.hashCode(status, value); + } + + @Override + public String toString() { + ToStringHelper stringHelper = MoreObjects.toStringHelper(this); + if (status == null) { + stringHelper.add("value", value); + } else { + stringHelper.add("error", status); + } + return stringHelper.toString(); + } + + @Nullable + private final Status status; + private final T value; +} diff --git a/api/src/main/java/io/grpc/StatusRuntimeException.java b/api/src/main/java/io/grpc/StatusRuntimeException.java index 70c4d10f0b2..ebcc2f0d671 100644 --- a/api/src/main/java/io/grpc/StatusRuntimeException.java +++ b/api/src/main/java/io/grpc/StatusRuntimeException.java @@ -26,11 +26,13 @@ public class StatusRuntimeException extends RuntimeException { private static final long serialVersionUID = 1950934672280720624L; + @SuppressWarnings("serial") // https://github.com/grpc/grpc-java/issues/1913 private final Status status; + @SuppressWarnings("serial") private final Metadata trailers; /** - * Constructs the exception with both a status. See also {@link Status#asRuntimeException()}. + * Constructs the exception with a status. See also {@link Status#asRuntimeException()}. 
* * @since 1.0.0 */ @@ -45,12 +47,7 @@ public StatusRuntimeException(Status status) { * @since 1.0.0 */ public StatusRuntimeException(Status status, @Nullable Metadata trailers) { - this(status, trailers, /*fillInStackTrace=*/ true); - } - - StatusRuntimeException(Status status, @Nullable Metadata trailers, boolean fillInStackTrace) { - super(Status.formatThrowableMessage(status), status.getCause(), - /* enable suppressions */ true, /* writableStackTrace */ fillInStackTrace); + super(Status.formatThrowableMessage(status), status.getCause()); this.status = status; this.trailers = trailers; } diff --git a/api/src/main/java/io/grpc/SynchronizationContext.java b/api/src/main/java/io/grpc/SynchronizationContext.java index 5a7677ac15f..94916a1b473 100644 --- a/api/src/main/java/io/grpc/SynchronizationContext.java +++ b/api/src/main/java/io/grpc/SynchronizationContext.java @@ -18,8 +18,10 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import static io.grpc.TimeUtils.convertToNanos; import java.lang.Thread.UncaughtExceptionHandler; +import java.time.Duration; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; @@ -162,6 +164,12 @@ public String toString() { return new ScheduledHandle(runnable, future); } + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11657") + public final ScheduledHandle schedule( + final Runnable task, Duration delay, ScheduledExecutorService timerService) { + return schedule(task, convertToNanos(delay), TimeUnit.NANOSECONDS, timerService); + } + /** * Schedules a task to be added and run via {@link #execute} after an initial delay and then * repeated after the delay until cancelled. 
@@ -193,6 +201,14 @@ public String toString() { return new ScheduledHandle(runnable, future); } + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11657") + public final ScheduledHandle scheduleWithFixedDelay( + final Runnable task, Duration initialDelay, Duration delay, + ScheduledExecutorService timerService) { + return scheduleWithFixedDelay(task, convertToNanos(initialDelay), convertToNanos(delay), + TimeUnit.NANOSECONDS, timerService); + } + private static class ManagedRunnable implements Runnable { final Runnable task; @@ -246,4 +262,4 @@ public boolean isPending() { return !(runnable.hasStarted || runnable.isCancelled); } } -} +} \ No newline at end of file diff --git a/api/src/main/java/io/grpc/TimeUtils.java b/api/src/main/java/io/grpc/TimeUtils.java new file mode 100644 index 00000000000..01b8c158822 --- /dev/null +++ b/api/src/main/java/io/grpc/TimeUtils.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.time.Duration; + +final class TimeUtils { + private TimeUtils() {} + + @IgnoreJRERequirement + static long convertToNanos(Duration duration) { + try { + return duration.toNanos(); + } catch (ArithmeticException tooBig) { + return duration.isNegative() ? 
Long.MIN_VALUE : Long.MAX_VALUE; + } + } +} diff --git a/api/src/main/java/io/grpc/Uri.java b/api/src/main/java/io/grpc/Uri.java new file mode 100644 index 00000000000..9f8a5a87848 --- /dev/null +++ b/api/src/main/java/io/grpc/Uri.java @@ -0,0 +1,1143 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.common.net.InetAddresses; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import java.net.InetAddress; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.CodingErrorAction; +import java.nio.charset.MalformedInputException; +import java.nio.charset.StandardCharsets; +import java.util.BitSet; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import javax.annotation.Nullable; + +/** + * A not-quite-general-purpose representation of a Uniform Resource Identifier (URI), as defined by + * RFC 3986. + * + *

The URI

+ * + *

A URI identifies a resource by its name or location or both. The resource could be a file, + * service, or some other abstract entity. + * + *

Examples

+ * + *
    + *
  • http://admin@example.com:8080/controlpanel?filter=users#settings + *
  • ftp://[2001:db8::7]/docs/report.pdf + *
  • file:///My%20Computer/Documents/letter.doc + *
  • dns://8.8.8.8/storage.googleapis.com + *
  • mailto:John.Doe@example.com + *
  • tel:+1-206-555-1212 + *
  • urn:isbn:978-1492082798 + *
+ * + *

Limitations

+ * + *

This class aims to meet the needs of grpc-java itself and RPC related code that depend on it. + * It isn't quite general-purpose. It definitely would not be suitable for building an HTTP user + * agent or proxy server. In particular, it: + * + *

    + *
  • Can only represent a URI, not a "URI-reference" or "relative reference". In other words, a + * "scheme" is always required. + *
  • Has no knowledge of the particulars of any scheme, with respect to normalization and + * comparison. We don't know https://google.com is the same as + * https://google.com:443, that file:/// is the same as + * file://localhost, or that joe@example.com is the same as + * joe@EXAMPLE.COM. No one class can or should know everything about every scheme so + * all this is better handled at a higher layer. + *
  • Implements {@link #equals(Object)} as a char-by-char comparison. Expect false negatives. + *
  • Does not support "IPvFuture" literal addresses. + *
  • Does not reflect how web browsers parse user input or the URL Living Standard. + *
  • Does not support different character encodings. Assumes UTF-8 in several places. + *
+ * + *

Migrating from RFC 2396 and {@link java.net.URI}

+ * + *

Those migrating from {@link java.net.URI} and/or its primary specification in RFC 2396 should + * note some differences. + * + *

Uniform Hierarchical Syntax

+ * + *

RFC 3986 unifies the older ideas of "hierarchical" and "opaque" URIs into a single generic + * syntax. What RFC 2396 called an opaque "scheme-specific part" is always broken out by RFC 3986 + * into an authority and path hierarchy, followed by query and fragment components. Accordingly, + * this class has only getters for those components but no {@link + * java.net.URI#getSchemeSpecificPart()} analog. + * + *

The RFC 3986 definition of path is now more liberal to accommodate this: + * + *

    + *
  • Path doesn't have to start with a slash. For example, the path of + * urn:isbn:978-1492082798 is isbn:978-1492082798 even though it doesn't + * look much like a file system path. + *
  • The path can now be empty. So Android's + * intent:#Intent;action=MAIN;category=LAUNCHER;end is now a valid {@link Uri}. Even + * the scheme-only about: is now valid. + *
+ * + *

The uniform syntax always understands what follows a '?' to be a query string. For example, + * mailto:me@example.com?subject=foo now has a query component whereas RFC 2396 + * considered everything after the mailto: scheme to be opaque. + * + *

Same goes for fragment. data:image/png;...#xywh=0,0,10,10 now has a fragment + * whereas RFC 2396 considered everything after the scheme to be opaque. + * + *

Uniform Authority Syntax

+ * + *

RFC 2396 tried to guess if an authority was a "server" (host:port) or "registry-based" + * (arbitrary string) based on its contents. RFC 3986 expects every authority to look like + * [userinfo@]host[:port] and loosens the definition of a "host" to accommodate. Accordingly, this + * class has no equivalent to {@link java.net.URI#parseServerAuthority()} -- authority was parsed + * into its components and checked for validity when the {@link Uri} was created. + * + *

Other Specific Differences

+ * + *

RFC 2396 does not allow underscores in a host name, meaning {@link java.net.URI} switches to + * opaque mode when it sees one. {@link Uri} does allow underscores in host, to accommodate + * registries other than DNS. So http://my_site.com:8080/index.html now parses as a + * host, port and path rather than a single opaque scheme-specific part. + * + *

{@link Uri} strictly *requires* square brackets in the query string and fragment to be + * percent-encoded whereas RFC 2396 merely recommended doing so. + * + *

Other URx classes are "liberal in what they accept and strict in what they produce." {@link + * Uri#parse(String)} and {@link Uri#create(String)}, however, are strict in what they accept and + * transparent when asked to reproduce it via {@link Uri#toString()}. The former policy may be + * appropriate for parsing user input or web content, but this class is meant for gRPC clients, + * servers and plugins like name resolvers where human error at runtime is less likely and best + * detected early. {@link java.net.URI#create(String)} is similarly strict, which makes migration + * easy, except for the server/registry-based ambiguity addressed by {@link + * java.net.URI#parseServerAuthority()}. + * + *

{@link java.net.URI} and {@link Uri} both support IPv6 literals in square brackets as defined + * by RFC 2732. + * + *

{@link java.net.URI} supports IPv6 scope IDs but accepts and emits a non-standard syntax. + * {@link Uri} implements the newer RFC 6874, which percent encodes scope IDs and the % delimiter + * itself. RFC 9844 claims to obsolete RFC 6874 because web browsers would not support it. This + * class implements RFC 6874 anyway, mostly to avoid creating a barrier to migration away from + * {@link java.net.URI}. + * + *

Some URI components, e.g. scheme, are required while others may or may not be present, e.g. + * authority. {@link Uri} is careful to preserve the distinction between an absent string component + * (getter returns null) and one with an empty value (getter returns ""). {@link java.net.URI} makes + * this distinction too, *except* when it comes to the authority and host components: {@link + * java.net.URI#getAuthority()} and {@link java.net.URI#getHost()} return null when an authority is + * absent, e.g. file:/path as expected. But these methods surprisingly also return null + * when the authority is the empty string, e.g.file:///path. {@link Uri}'s getters + * correctly return null and "" in these cases, respectively, as one would expect. + */ +@Internal +public final class Uri { + // Components are stored percent-encoded, just as originally parsed for transparent parse/toString + // round-tripping. + private final String scheme; // != null since we don't support relative references. + @Nullable private final String userInfo; + @Nullable private final String host; + @Nullable private final String port; + private final String path; // In RFC 3986, path is always defined (but can be empty). + @Nullable private final String query; + @Nullable private final String fragment; + + private Uri(Builder builder) { + this.scheme = checkNotNull(builder.scheme, "scheme"); + this.userInfo = builder.userInfo; + this.host = builder.host; + this.port = builder.port; + this.path = builder.path; + this.query = builder.query; + this.fragment = builder.fragment; + + // Checks common to the parse() and Builder code paths. + if (hasAuthority()) { + if (!path.isEmpty() && !path.startsWith("/")) { + throw new IllegalArgumentException("Has authority -- Non-empty path must start with '/'"); + } + } else { + if (path.startsWith("//")) { + throw new IllegalArgumentException("No authority -- Path cannot start with '//'"); + } + } + } + + /** + * Parses a URI from its string form. 
+ * + * @throws URISyntaxException if 's' is not a valid RFC 3986 URI. + */ + public static Uri parse(String s) throws URISyntaxException { + try { + return create(s); + } catch (IllegalArgumentException e) { + throw new URISyntaxException(s, e.getMessage()); + } + } + + /** + * Creates a URI from a string assumed to be valid. + * + *

Useful for defining URI constants in code. Not for user input. + * + * @throws IllegalArgumentException if 's' is not a valid RFC 3986 URI. + */ + public static Uri create(String s) { + Builder builder = new Builder(); + int i = 0; + final int n = s.length(); + + // 3.1. Scheme: Look for a ':' before '/', '?', or '#'. + int schemeColon = -1; + for (; i < n; ++i) { + char c = s.charAt(i); + if (c == ':') { + schemeColon = i; + break; + } else if (c == '/' || c == '?' || c == '#') { + break; + } + } + if (schemeColon < 0) { + throw new IllegalArgumentException("Missing required scheme."); + } + builder.setRawScheme(s.substring(0, schemeColon)); + + // 3.2. Authority. Look for '//' then keep scanning until '/', '?', or '#'. + i = schemeColon + 1; + if (i + 1 < n && s.charAt(i) == '/' && s.charAt(i + 1) == '/') { + // "//" just means we have an authority. Skip over it. + i += 2; + + int authorityStart = i; + for (; i < n; ++i) { + char c = s.charAt(i); + if (c == '/' || c == '?' || c == '#') { + break; + } + } + builder.setRawAuthority(s.substring(authorityStart, i)); + } + + // 3.3. Path: Whatever is left before '?' or '#'. + int pathStart = i; + for (; i < n; ++i) { + char c = s.charAt(i); + if (c == '?' || c == '#') { + break; + } + } + builder.setRawPath(s.substring(pathStart, i)); + + // 3.4. Query, if we stopped at '?'. + if (i < n && s.charAt(i) == '?') { + i++; // Skip '?' + int queryStart = i; + for (; i < n; ++i) { + char c = s.charAt(i); + if (c == '#') { + break; + } + } + builder.setRawQuery(s.substring(queryStart, i)); + } + + // 3.5. Fragment, if we stopped at '#'. + if (i < n && s.charAt(i) == '#') { + ++i; // Skip '#' + builder.setRawFragment(s.substring(i)); + } + + return builder.build(); + } + + private static int findPortStartColon(String authority, int hostStart) { + for (int i = authority.length() - 1; i >= hostStart; --i) { + char c = authority.charAt(i); + if (c == ':') { + return i; + } + if (c == ']') { + // Hit the end of IP-literal. 
Any further colon is inside it and couldn't indicate a port. + break; + } + if (!digitChars.get(c)) { + // Found a non-digit, non-colon, non-bracket. + // This means there is no valid port (e.g. host is "example.com") + break; + } + } + return -1; + } + + // Checks a raw path for validity and parses it into segments. Let 'out' be null to just validate. + private static void parseAssumedUtf8PathIntoSegments( + String path, ImmutableList.Builder out) { + // Skip the first slash so it doesn't count as an empty segment at the start. + // (e.g., "/a" -> ["a"], not ["", "a"]) + int start = path.startsWith("/") ? 1 : 0; + + for (int i = start; i < path.length(); ) { + int nextSlash = path.indexOf('/', i); + String segment; + if (nextSlash >= 0) { + // Typical segment case (e.g., "foo" in "/foo/bar"). + segment = path.substring(i, nextSlash); + i = nextSlash + 1; + } else { + // Final segment case (e.g., "bar" in "/foo/bar"). + segment = path.substring(i); + i = path.length(); + } + if (out != null) { + out.add(percentDecodeAssumedUtf8(segment)); + } else { + checkPercentEncodedArg(segment, "path segment", pChars); + } + } + + // RFC 3986 says a trailing slash creates a final empty segment. + // (e.g., "/foo/" -> ["foo", ""]) + if (path.endsWith("/") && out != null) { + out.add(""); + } + } + + /** Returns the scheme of this URI. */ + public String getScheme() { + return scheme; + } + + /** + * Returns the percent-decoded "Authority" component of this URI, or null if not present. + * + *

NB: This method's decoding is lossy -- It only exists for compatibility with {@link + * java.net.URI}. Prefer {@link #getRawAuthority()} or work instead with authority in terms of its + * individual components ({@link #getUserInfo()}, {@link #getHost()} and {@link #getPort()}). The + * problem with getAuthority() is that it returns the delimited concatenation of the percent- + * decoded userinfo, host and port components. But both userinfo and host can contain the '@' + * character, which becomes indistinguishable from the userinfo/host delimiter after decoding. For + * example, URIs scheme://x@y%40z and scheme://x%40y@z have different + * userinfo and host components but getAuthority() returns "x@y@z" for both of them. + * + *

NB: This method assumes the "host" component was encoded as UTF-8, as mandated by RFC 3986. + * This method also assumes the "user information" part of authority was encoded as UTF-8, + * although RFC 3986 doesn't specify an encoding. + * + *

Decoding errors are indicated by a {@code '\u005CuFFFD'} unicode replacement character in + * the output. Callers who want to detect and handle errors in some other way should call {@link + * #getRawAuthority()}, {@link #percentDecode(CharSequence)}, then decode the bytes for + * themselves. + */ + @Nullable + public String getAuthority() { + return percentDecodeAssumedUtf8(getRawAuthority()); + } + + private boolean hasAuthority() { + return host != null; + } + + /** + * Returns the "authority" component of this URI in its originally parsed, possibly + * percent-encoded form. + */ + @Nullable + public String getRawAuthority() { + if (hasAuthority()) { + StringBuilder sb = new StringBuilder(); + appendAuthority(sb); + return sb.toString(); + } + return null; + } + + private void appendAuthority(StringBuilder sb) { + if (userInfo != null) { + sb.append(userInfo).append('@'); + } + if (host != null) { + sb.append(host); + } + if (port != null) { + sb.append(':').append(port); + } + } + + /** + * Returns the percent-decoded "User Information" component of this URI, or null if not present. + * + *

NB: This method *assumes* this component was encoded as UTF-8, although RFC 3986 doesn't + * specify an encoding. + * + *

Decoding errors are indicated by a {@code '\u005CuFFFD'} unicode replacement character in + * the output. Callers who want to detect and handle errors in some other way should call {@link + * #getRawUserInfo()}, {@link #percentDecode(CharSequence)}, then decode the bytes for themselves. + */ + @Nullable + public String getUserInfo() { + return percentDecodeAssumedUtf8(userInfo); + } + + /** + * Returns the "User Information" component of this URI in its originally parsed, possibly + * percent-encoded form. + */ + @Nullable + public String getRawUserInfo() { + return userInfo; + } + + /** + * Returns the percent-decoded "host" component of this URI, or null if not present. + * + *

This method assumes the host was encoded as UTF-8, as mandated by RFC 3986. + * + *

Decoding errors are indicated by a {@code '\u005CuFFFD'} unicode replacement character in + * the output. Callers who want to detect and handle errors in some other way should call {@link + * #getRawHost()}, {@link #percentDecode(CharSequence)}, then decode the bytes for themselves. + */ + @Nullable + public String getHost() { + return percentDecodeAssumedUtf8(host); + } + + /** + * Returns the host component of this URI in its originally parsed, possibly percent-encoded form. + */ + @Nullable + public String getRawHost() { + return host; + } + + /** Returns the "port" component of this URI, or -1 if empty or not present. */ + public int getPort() { + return port != null && !port.isEmpty() ? Integer.parseInt(port) : -1; + } + + /** Returns the raw port component of this URI in its originally parsed form. */ + @Nullable + public String getRawPort() { + return port; + } + + /** + * Returns the (possibly empty) percent-decoded "path" component of this URI. + * + *

NB: This method *assumes* the path was encoded as UTF-8, although RFC 3986 doesn't specify + * an encoding. + * + *

Decoding errors are indicated by a {@code '\u005CuFFFD'} unicode replacement character in + * the output. Callers who want to detect and handle errors in some other way should call {@link + * #getRawPath()}, {@link #percentDecode(CharSequence)}, then decode the bytes for themselves. + * + *

NB: Prefer {@link #getPathSegments()} because this method's decoding is lossy. For example, + * consider these (different) URIs: + * + *

    + *
  • file:///home%2Ffolder/my%20file + *
  • file:///home/folder/my%20file + *
+ * + *

Calling getPath() on each returns the same string: /home/folder/my file. You + * can't tell whether the second '/' character is part of the first path segment or separates the + * first and second path segments. This method only exists to ease migration from {@link + * java.net.URI}. + */ + public String getPath() { + return percentDecodeAssumedUtf8(path); + } + + /** + * Returns this URI's path as a list of path segments not including the '/' segment delimiters. + * + *

Prefer this method over {@link #getPath()} because it preserves the distinction between + * segment separators and literal '/'s within a path segment. + * + *

A trailing '/' delimiter in the path results in the empty string as the last element in the + * returned list. For example, file://localhost/foo/bar/ has path segments + * ["foo", "bar", ""] + * + *

A leading '/' delimiter cannot be detected using this method. For example, both + * dns:example.com and dns:///example.com have the same list of path segments: + * ["example.com"]. Use {@link #isPathAbsolute()} or {@link #isPathRootless()} to + * distinguish these cases. + * + *

The returned list is immutable. + */ + public List getPathSegments() { + // Returned list must be immutable but we intentionally keep guava out of the public API. + ImmutableList.Builder segmentsBuilder = ImmutableList.builder(); + parseAssumedUtf8PathIntoSegments(path, segmentsBuilder); + return segmentsBuilder.build(); + } + + /** + * Returns true iff this URI's path component starts with a path segment (rather than the '/' + * segment delimiter). + * + *

The path of an RFC 3986 URI is either empty, absolute (starts with the '/' segment + * delimiter) or rootless (starts with a path segment). For example, tel:+1-206-555-1212 + * , mailto:me@example.com and urn:isbn:978-1492082798 all have + * rootless paths. mailto:%2Fdev%2Fnull@example.com is also rootless because its + * percent-encoded slashes are not segment delimiters but rather part of the first and only path + * segment. + * + *

Contrast rootless paths with absolute ones (see {@link #isPathAbsolute()}). + */ + public boolean isPathRootless() { + return !path.isEmpty() && !path.startsWith("/"); + } + + /** + * Returns true iff this URI's path component starts with the '/' segment delimiter (rather than a + * path segment). + *

The path of an RFC 3986 URI is either empty, absolute (starts with the '/' segment + * delimiter) or rootless (starts with a path segment). For example, file:///resume.txt + * , file:/resume.txt and file://localhost/ all have absolute + * paths while tel:+1-206-555-1212's path is not absolute. + * mailto:%2Fdev%2Fnull@example.com is also not absolute because its percent-encoded + * slashes are not segment delimiters but rather part of the first and only path segment. + * + *

Contrast absolute paths with rootless ones (see {@link #isPathRootless()}). + *

NB: The term "absolute" has two different meanings in RFC 3986 which are easily confused. + * This method tests for a property of this URI's path component. Contrast with {@link + * #isAbsolute()} which tests the URI itself for a different property. + */ + public boolean isPathAbsolute() { + return path.startsWith("/"); + } + + /** + * Returns the path component of this URI in its originally parsed, possibly percent-encoded form. + */ + public String getRawPath() { + return path; + } + + /** + * Returns the percent-decoded "query" component of this URI, or null if not present. + * + *

NB: This method assumes the query was encoded as UTF-8, although RFC 3986 doesn't specify an + * encoding. + * + *

Decoding errors are indicated by a {@code '\u005CuFFFD'} unicode replacement character in + * the output. Callers who want to detect and handle errors in some other way should call {@link + * #getRawQuery()}, {@link #percentDecode(CharSequence)}, then decode the bytes for themselves. + */ + @Nullable + public String getQuery() { + return percentDecodeAssumedUtf8(query); + } + + /** + * Returns the query component of this URI in its originally parsed, possibly percent-encoded + * form, without any leading '?' character. + */ + @Nullable + public String getRawQuery() { + return query; + } + + /** + * Returns the percent-decoded "fragment" component of this URI, or null if not present. + * + *

NB: This method assumes the fragment was encoded as UTF-8, although RFC 3986 doesn't specify + * an encoding. + * + *

Decoding errors are indicated by a {@code '\u005CuFFFD'} unicode replacement character in + * the output. Callers who want to detect and handle errors in some other way should call {@link + * #getRawFragment()}, {@link #percentDecode(CharSequence)}, then decode the bytes for themselves. + */ + @Nullable + public String getFragment() { + return percentDecodeAssumedUtf8(fragment); + } + + /** + * Returns the fragment component of this URI in its original, possibly percent-encoded form, and + * without any leading '#' character. + */ + @Nullable + public String getRawFragment() { + return fragment; + } + + /** + * {@inheritDoc} + * + *

If this URI was created by {@link #parse(String)} or {@link #create(String)}, then the + * returned string will match that original input exactly. + */ + @Override + public String toString() { + // https://datatracker.ietf.org/doc/html/rfc3986#section-5.3 + StringBuilder sb = new StringBuilder(); + sb.append(scheme).append(':'); + if (hasAuthority()) { + sb.append("//"); + appendAuthority(sb); + } + sb.append(path); + if (query != null) { + sb.append('?').append(query); + } + if (fragment != null) { + sb.append('#').append(fragment); + } + return sb.toString(); + } + + /** + * Returns true iff this URI has a scheme and an authority/path hierarchy, but no fragment. + * + *

All instances of {@link Uri} are RFC 3986 URIs, not "relative references", so this method is + * equivalent to {@code getFragment() == null}. It mostly exists for compatibility with {@link + * java.net.URI}. + */ + public boolean isAbsolute() { + return scheme != null && fragment == null; + } + + /** + * {@inheritDoc} + * + *

Two instances of {@link Uri} are equal if and only if they have the same string + * representation, which RFC 3986 calls "Simple String Comparison" (6.2.1). Callers with a higher + * layer expectation of equality (e.g. http://some%2Dhost:80/foo/./bar.txt ~= + * http://some-host/foo/bar.txt) will experience false negatives. + */ + @Override + public boolean equals(Object otherObj) { + if (!(otherObj instanceof Uri)) { + return false; + } + Uri other = (Uri) otherObj; + return Objects.equals(scheme, other.scheme) + && Objects.equals(userInfo, other.userInfo) + && Objects.equals(host, other.host) + && Objects.equals(port, other.port) + && Objects.equals(path, other.path) + && Objects.equals(query, other.query) + && Objects.equals(fragment, other.fragment); + } + + @Override + public int hashCode() { + return Objects.hash(scheme, userInfo, host, port, path, query, fragment); + } + + /** Returns a new Builder initialized with the fields of this URI. */ + public Builder toBuilder() { + return new Builder(this); + } + + /** Creates a new {@link Builder} with all fields uninitialized or set to their default values. */ + public static Builder newBuilder() { + return new Builder(); + } + + /** Builder for {@link Uri}. */ + public static final class Builder { + private String scheme; + private String path = ""; + private String query; + private String fragment; + private String userInfo; + private String host; + private String port; + + private Builder() {} + + Builder(Uri prototype) { + this.scheme = prototype.scheme; + this.userInfo = prototype.userInfo; + this.host = prototype.host; + this.port = prototype.port; + this.path = prototype.path; + this.query = prototype.query; + this.fragment = prototype.fragment; + } + + /** + * Sets the scheme, e.g. "https", "dns" or "xds". + * + *

This field is required. + * + * @return this, for fluent building + * @throws IllegalArgumentException if the scheme is invalid. + */ + @CanIgnoreReturnValue + public Builder setScheme(String scheme) { + return setRawScheme(scheme.toLowerCase(Locale.ROOT)); + } + + @CanIgnoreReturnValue + Builder setRawScheme(String scheme) { + if (scheme.isEmpty() || !alphaChars.get(scheme.charAt(0))) { + throw new IllegalArgumentException("Scheme must start with an alphabetic char"); + } + for (int i = 0; i < scheme.length(); i++) { + char c = scheme.charAt(i); + if (!schemeChars.get(c)) { + throw new IllegalArgumentException("Invalid character in scheme at index " + i); + } + } + this.scheme = scheme; + return this; + } + + /** + * Specifies the new URI's path component as a string of zero or more '/' delimited segments. + * + *

Path segments can consist of any string of codepoints. Codepoints that can't be encoded + * literally will be percent-encoded for you. + * + *

If a URI contains an authority component, then the path component must either be empty or + * begin with a slash ("/") character. If a URI does not contain an authority component, then + * the path cannot begin with two slash characters ("//"). + * + *

This method interprets all '/' characters in 'path' as segment delimiters. If any of your + * segments contain literal '/' characters, call {@link #setRawPath(String)} instead. + * + *

See RFC 3986 3.3 + * for more. + * + *

This field is required but can be empty (its default value). + * + * @param path the new path + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setPath(String path) { + checkArgument(path != null, "Path can be empty but not null"); + this.path = percentEncode(path, pCharsAndSlash); + return this; + } + + /** + * Specifies the new URI's path component as a string of zero or more '/' delimited segments. + * + *

Path segments can consist of any string of codepoints but the caller must first percent- + * encode anything other than RFC 3986's "pchar" character class using UTF-8. + * + *

If a URI contains an authority component, then the path component must either be empty or + * begin with a slash ("/") character. If a URI does not contain an authority component, then + * the path cannot begin with two slash characters ("//"). + * + *

This method interprets all '/' characters in 'path' as segment delimiters. If any of your + * segments contain literal '/' characters, you must percent-encode them. + * + *

See RFC 3986 3.3 + * for more. + * + *

This field is required but can be empty (its default value). + * + * @param path the new path, a string consisting of characters from "pchar" + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setRawPath(String path) { + checkArgument(path != null, "Path can be empty but not null"); + parseAssumedUtf8PathIntoSegments(path, null); + this.path = path; + return this; + } + + /** + * Specifies the query component of the new URI (not including the leading '?'). + * + *

Query can contain any string of codepoints. Codepoints that can't be encoded literally + * will be percent-encoded for you as UTF-8. + * + *

This field is optional. + * + * @param query the new query component, or null to clear this field + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setQuery(@Nullable String query) { + this.query = percentEncode(query, queryChars); + return this; + } + + @CanIgnoreReturnValue + Builder setRawQuery(String query) { + checkPercentEncodedArg(query, "query", queryChars); + this.query = query; + return this; + } + + /** + * Specifies the fragment component of the new URI (not including the leading '#'). + * + *

The fragment can contain any string of codepoints. Codepoints that can't be encoded + * literally will be percent-encoded for you as UTF-8. + * + *

This field is optional. + * + * @param fragment the new fragment component, or null to clear this field + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setFragment(@Nullable String fragment) { + this.fragment = percentEncode(fragment, fragmentChars); + return this; + } + + @CanIgnoreReturnValue + Builder setRawFragment(String fragment) { + checkPercentEncodedArg(fragment, "fragment", fragmentChars); + this.fragment = fragment; + return this; + } + + /** + * Set the "user info" component of the new URI, e.g. "username:password", not including the + * trailing '@' character. + * + *

User info can contain any string of codepoints. Codepoints that can't be encoded literally + * will be percent-encoded for you as UTF-8. + * + *

This field is optional. + * + * @param userInfo the new "user info" component, or null to clear this field + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setUserInfo(@Nullable String userInfo) { + this.userInfo = percentEncode(userInfo, userInfoChars); + return this; + } + + @CanIgnoreReturnValue + Builder setRawUserInfo(String userInfo) { + checkPercentEncodedArg(userInfo, "userInfo", userInfoChars); + this.userInfo = userInfo; + return this; + } + + /** + * Specifies the "host" component of the new URI in its "registered name" form (usually DNS), + * e.g. "server.com". + * + *

The registered name can contain any string of codepoints. Codepoints that can't be encoded + * literally will be percent-encoded for you as UTF-8. + * + *

This field is optional. + * + * @param regName the new host component in "registered name" form, or null to clear this field + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setHost(@Nullable String regName) { + if (regName != null) { + regName = regName.toLowerCase(Locale.ROOT); + regName = percentEncode(regName, regNameChars); + } + this.host = regName; + return this; + } + + /** + * Specifies the "host" component of the new URI as an IP address. + * + *

This field is optional. + * + * @param addr the new "host" component in InetAddress form, or null to clear this field + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setHost(@Nullable InetAddress addr) { + this.host = addr != null ? toUriString(addr) : null; + return this; + } + + private static String toUriString(InetAddress addr) { + // InetAddresses.toUriString(addr) is almost enough but neglects RFC 6874 percent encoding. + String inetAddrStr = InetAddresses.toUriString(addr); + int percentIndex = inetAddrStr.indexOf('%'); + if (percentIndex < 0) { + return inetAddrStr; + } + + String scope = inetAddrStr.substring(percentIndex, inetAddrStr.length() - 1); + return inetAddrStr.substring(0, percentIndex) + percentEncode(scope, unreservedChars) + "]"; + } + + @CanIgnoreReturnValue + Builder setRawHost(String host) { + if (host.startsWith("[") && host.endsWith("]")) { + // IP-literal: Guava's isUriInetAddress() is almost enough but it doesn't check the scope. + int percentIndex = host.indexOf('%'); + if (percentIndex > 0) { + String scope = host.substring(percentIndex, host.length() - 1); + checkPercentEncodedArg(scope, "scope", unreservedChars); + } + } + // IP-literal validation is complicated so we delegate it to Guava. We use this particular + // method of InetAddresses because it doesn't try to match interfaces on the local machine. + // (The validity of a URI should be the same no matter which machine does the parsing.) + // TODO(jdcormie): IPFuture + if (!InetAddresses.isUriInetAddress(host)) { + // Must be a "registered name". + checkPercentEncodedArg(host, "host", regNameChars); + } + this.host = host; + return this; + } + + /** + * Specifies the "port" component of the new URI, e.g. "8080". + * + *

The port can be any non-negative integer. A negative value represents "no port". + * + *

This field is optional. + * + * @param port the new "port" component, or -1 to clear this field + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setPort(int port) { + this.port = port < 0 ? null : Integer.toString(port); + return this; + } + + @CanIgnoreReturnValue + Builder setRawPort(String port) { + if (port != null && !port.isEmpty()) { + try { + Integer.parseInt(port); // Result unused. + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid port", e); + } + } + this.port = port; + return this; + } + + /** + * Specifies the userinfo, host and port URI components all at once using a single string. + * + *

This setter is "raw" in the sense that special characters in userinfo and host must be + * passed in percent-encoded. See RFC 3986 3.2 for the set + * of characters allowed in each component of an authority. + * + *

There's no "cooked" method to set authority like for other URI components because + * authority is a *compound* URI component whose userinfo, host and port components are + * delimited with special characters '@' and ':'. But the first two of those components can + * themselves contain these delimiters so we need percent-encoding to parse them unambiguously. + * + * @param authority an RFC 3986 authority string that will be used to set userinfo, host and + * port, or null to clear all three of those components + */ + @CanIgnoreReturnValue + public Builder setRawAuthority(@Nullable String authority) { + if (authority == null) { + setUserInfo(null); + setHost((String) null); + setPort(-1); + } else { + // UserInfo. Easy because '@' cannot appear unencoded inside userinfo or host. + int userInfoEnd = authority.indexOf('@'); + if (userInfoEnd >= 0) { + setRawUserInfo(authority.substring(0, userInfoEnd)); + } else { + setUserInfo(null); + } + + // Host/Port. + int hostStart = userInfoEnd >= 0 ? userInfoEnd + 1 : 0; + int portStartColon = findPortStartColon(authority, hostStart); + if (portStartColon < 0) { + setRawHost(authority.substring(hostStart)); + setPort(-1); + } else { + setRawHost(authority.substring(hostStart, portStartColon)); + setRawPort(authority.substring(portStartColon + 1)); + } + } + return this; + } + + /** Builds a new instance of {@link Uri} as specified by the setters. */ + public Uri build() { + checkState(scheme != null, "Missing required scheme."); + if (host == null) { + checkState(port == null, "Cannot set port without host."); + checkState(userInfo == null, "Cannot set userInfo without host."); + } + return new Uri(this); + } + } + + /** + * Decodes a string of characters in the range [U+0000, U+007F] to bytes. + * + *

Each percent-encoded sequence (e.g. "%F0" or "%2a", as defined by RFC 3986 2.1) is decoded + * to the octet it encodes. Other characters are decoded to their code point's single byte value. + * A literal % character must be encoded as %25. + * + * @throws IllegalArgumentException if 's' contains characters out of range or invalid percent + * encoding sequences. + */ + public static ByteBuffer percentDecode(CharSequence s) { + // This is large enough because each input character needs *at most* one byte of output. + ByteBuffer outBuf = ByteBuffer.allocate(s.length()); + percentDecode(s, "input", null, outBuf); + outBuf.flip(); + return outBuf; + } + + private static void percentDecode( + CharSequence s, String what, BitSet allowedChars, ByteBuffer outBuf) { + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + if (c == '%') { + if (i + 2 >= s.length()) { + throw new IllegalArgumentException( + "Invalid percent-encoding at index " + i + " of " + what + ": " + s); + } + int h1 = Character.digit(s.charAt(i + 1), 16); + int h2 = Character.digit(s.charAt(i + 2), 16); + if (h1 == -1 || h2 == -1) { + throw new IllegalArgumentException( + "Invalid hex digit in " + what + " at index " + i + " of: " + s); + } + if (outBuf != null) { + outBuf.put((byte) (h1 << 4 | h2)); + } + i += 2; + } else if (allowedChars == null || allowedChars.get(c)) { + if (outBuf != null) { + outBuf.put((byte) c); + } + } else { + throw new IllegalArgumentException("Invalid character in " + what + " at index " + i); + } + } + } + + @Nullable + private static String percentDecodeAssumedUtf8(@Nullable String s) { + if (s == null || s.indexOf('%') == -1) { + return s; + } + + ByteBuffer utf8Bytes = percentDecode(s); + try { + return StandardCharsets.UTF_8 + .newDecoder() + .onMalformedInput(CodingErrorAction.REPLACE) + .onUnmappableCharacter(CodingErrorAction.REPLACE) + .decode(utf8Bytes) + .toString(); + } catch (CharacterCodingException e) { + throw new VerifyException(e); // Should not 
happen in REPLACE mode. + } + } + + @Nullable + private static String percentEncode(String s, BitSet allowedCodePoints) { + if (s == null) { + return null; + } + CharsetEncoder encoder = + StandardCharsets.UTF_8 + .newEncoder() + .onMalformedInput(CodingErrorAction.REPORT) + .onUnmappableCharacter(CodingErrorAction.REPORT); + ByteBuffer utf8Bytes; + try { + utf8Bytes = encoder.encode(CharBuffer.wrap(s)); + } catch (MalformedInputException e) { + throw new IllegalArgumentException("Malformed input", e); // Must be a broken surrogate pair. + } catch (CharacterCodingException e) { + throw new VerifyException(e); // Should not happen when encoding to UTF-8. + } + + StringBuilder sb = new StringBuilder(); + while (utf8Bytes.hasRemaining()) { + int b = 0xff & utf8Bytes.get(); + if (allowedCodePoints.get(b)) { + sb.append((char) b); + } else { + sb.append('%'); + sb.append(hexDigitsByVal[(b & 0xF0) >> 4]); + sb.append(hexDigitsByVal[b & 0x0F]); + } + } + return sb.toString(); + } + + private static void checkPercentEncodedArg(String s, String what, BitSet allowedChars) { + percentDecode(s, what, allowedChars, null); + } + + // See UriTest for how these were computed from the ABNF constants in RFC 3986. + static final BitSet digitChars = BitSet.valueOf(new long[] {0x3ff000000000000L}); + static final BitSet alphaChars = BitSet.valueOf(new long[] {0L, 0x7fffffe07fffffeL}); + // scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) + static final BitSet schemeChars = + BitSet.valueOf(new long[] {0x3ff680000000000L, 0x7fffffe07fffffeL}); + // unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" + static final BitSet unreservedChars = + BitSet.valueOf(new long[] {0x3ff600000000000L, 0x47fffffe87fffffeL}); + // gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" + static final BitSet genDelimsChars = + BitSet.valueOf(new long[] {0x8400800800000000L, 0x28000001L}); + // sub-delims = "!" 
/ "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "=" + static final BitSet subDelimsChars = BitSet.valueOf(new long[] {0x28001fd200000000L}); + // reserved = gen-delims / sub-delims + static final BitSet reservedChars = BitSet.valueOf(new long[] {0xac009fda00000000L, 0x28000001L}); + // reg-name = *( unreserved / pct-encoded / sub-delims ) + static final BitSet regNameChars = + BitSet.valueOf(new long[] {0x2bff7fd200000000L, 0x47fffffe87fffffeL}); + // userinfo = *( unreserved / pct-encoded / sub-delims / ":" ) + static final BitSet userInfoChars = + BitSet.valueOf(new long[] {0x2fff7fd200000000L, 0x47fffffe87fffffeL}); + // pchar = unreserved / pct-encoded / sub-delims / ":" / "@" + static final BitSet pChars = + BitSet.valueOf(new long[] {0x2fff7fd200000000L, 0x47fffffe87ffffffL}); + static final BitSet pCharsAndSlash = + BitSet.valueOf(new long[] {0x2fffffd200000000L, 0x47fffffe87ffffffL}); + // query = *( pchar / "/" / "?" ) + static final BitSet queryChars = + BitSet.valueOf(new long[] {0xafffffd200000000L, 0x47fffffe87ffffffL}); + // fragment = *( pchar / "/" / "?" 
) + static final BitSet fragmentChars = queryChars; + + private static final char[] hexDigitsByVal = "0123456789ABCDEF".toCharArray(); +} diff --git a/api/src/test/java/io/grpc/CallOptionsTest.java b/api/src/test/java/io/grpc/CallOptionsTest.java index cc90a9799d7..65fb7ff3bf2 100644 --- a/api/src/test/java/io/grpc/CallOptionsTest.java +++ b/api/src/test/java/io/grpc/CallOptionsTest.java @@ -32,6 +32,7 @@ import com.google.common.base.Objects; import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.internal.SerializingExecutor; +import java.time.Duration; import java.util.concurrent.Executor; import org.junit.Test; import org.junit.runner.RunWith; @@ -150,6 +151,15 @@ public void withDeadlineAfter() { assertAbout(deadline()).that(actual).isWithin(10, MILLISECONDS).of(expected); } + @Test + @IgnoreJRERequirement + public void withDeadlineAfterDuration() { + Deadline actual = CallOptions.DEFAULT.withDeadlineAfter(Duration.ofMinutes(1L)).getDeadline(); + Deadline expected = Deadline.after(1, MINUTES); + + assertAbout(deadline()).that(actual).isWithin(10, MILLISECONDS).of(expected); + } + @Test public void toStringMatches_noDeadline_default() { String actual = allSet diff --git a/api/src/test/java/io/grpc/ConfiguratorRegistryTest.java b/api/src/test/java/io/grpc/ConfiguratorRegistryTest.java index e231d13503a..457d5a36e77 100644 --- a/api/src/test/java/io/grpc/ConfiguratorRegistryTest.java +++ b/api/src/test/java/io/grpc/ConfiguratorRegistryTest.java @@ -85,14 +85,12 @@ public static final class StaticTestingClassLoaderGetBeforeSet implements Runnab @Override public void run() { assertThat(ConfiguratorRegistry.getDefaultRegistry().getConfigurators()).isEmpty(); - - try { - ConfiguratorRegistry.getDefaultRegistry() - .setConfigurators(Arrays.asList(new NoopConfigurator())); - fail("should have failed for invoking set call after get is already called"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().isEqualTo("Configurators are already set"); 
- } + NoopConfigurator noopConfigurator = new NoopConfigurator(); + ConfiguratorRegistry.getDefaultRegistry() + .setConfigurators(Arrays.asList(noopConfigurator)); + assertThat(ConfiguratorRegistry.getDefaultRegistry().getConfigurators()) + .containsExactly(noopConfigurator); + assertThat(InternalConfiguratorRegistry.getConfiguratorsCallCountBeforeSet()).isEqualTo(1); } } diff --git a/api/src/test/java/io/grpc/HttpConnectProxiedSocketAddressTest.java b/api/src/test/java/io/grpc/HttpConnectProxiedSocketAddressTest.java new file mode 100644 index 00000000000..6620a7d413a --- /dev/null +++ b/api/src/test/java/io/grpc/HttpConnectProxiedSocketAddressTest.java @@ -0,0 +1,248 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; + +import com.google.common.testing.EqualsTester; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HttpConnectProxiedSocketAddressTest { + + private final InetSocketAddress proxyAddress = + new InetSocketAddress(InetAddress.getLoopbackAddress(), 8080); + private final InetSocketAddress targetAddress = + InetSocketAddress.createUnresolved("example.com", 443); + + @Test + public void buildWithAllFields() { + Map headers = new HashMap<>(); + headers.put("X-Custom-Header", "custom-value"); + headers.put("Proxy-Authorization", "Bearer token"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .setUsername("user") + .setPassword("pass") + .build(); + + assertThat(address.getProxyAddress()).isEqualTo(proxyAddress); + assertThat(address.getTargetAddress()).isEqualTo(targetAddress); + assertThat(address.getHeaders()).hasSize(2); + assertThat(address.getHeaders()).containsEntry("X-Custom-Header", "custom-value"); + assertThat(address.getHeaders()).containsEntry("Proxy-Authorization", "Bearer token"); + assertThat(address.getUsername()).isEqualTo("user"); + assertThat(address.getPassword()).isEqualTo("pass"); + } + + @Test + public void buildWithoutOptionalFields() { + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .build(); + + assertThat(address.getProxyAddress()).isEqualTo(proxyAddress); + 
assertThat(address.getTargetAddress()).isEqualTo(targetAddress); + assertThat(address.getHeaders()).isEmpty(); + assertThat(address.getUsername()).isNull(); + assertThat(address.getPassword()).isNull(); + } + + @Test + public void buildWithEmptyHeaders() { + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(Collections.emptyMap()) + .build(); + + assertThat(address.getHeaders()).isEmpty(); + } + + @Test + public void headersAreImmutable() { + Map headers = new HashMap<>(); + headers.put("key1", "value1"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .build(); + + headers.put("key2", "value2"); + + assertThat(address.getHeaders()).hasSize(1); + assertThat(address.getHeaders()).containsEntry("key1", "value1"); + assertThat(address.getHeaders()).doesNotContainKey("key2"); + } + + @Test + public void returnedHeadersAreUnmodifiable() { + Map headers = new HashMap<>(); + headers.put("key", "value"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .build(); + + assertThrows(UnsupportedOperationException.class, + () -> address.getHeaders().put("newKey", "newValue")); + } + + @Test + public void nullHeadersThrowsException() { + assertThrows(NullPointerException.class, + () -> HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(null) + .build()); + } + + @Test + public void equalsAndHashCode() { + Map headers1 = new HashMap<>(); + headers1.put("header", "value"); + + Map headers2 = new HashMap<>(); + headers2.put("header", "value"); + + Map differentHeaders = new HashMap<>(); + differentHeaders.put("different", 
"header"); + + new EqualsTester() + .addEqualityGroup( + HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers1) + .setUsername("user") + .setPassword("pass") + .build(), + HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers2) + .setUsername("user") + .setPassword("pass") + .build()) + .addEqualityGroup( + HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(differentHeaders) + .setUsername("user") + .setPassword("pass") + .build()) + .addEqualityGroup( + HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .build()) + .testEquals(); + } + + @Test + public void toStringContainsHeaders() { + Map headers = new HashMap<>(); + headers.put("X-Test", "test-value"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .setUsername("user") + .setPassword("secret") + .build(); + + String toString = address.toString(); + assertThat(toString).contains("headers"); + assertThat(toString).contains("X-Test"); + assertThat(toString).contains("hasPassword=true"); + assertThat(toString).doesNotContain("secret"); + } + + @Test + public void toStringWithoutPassword() { + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .build(); + + String toString = address.toString(); + assertThat(toString).contains("hasPassword=false"); + } + + @Test + public void hashCodeDependsOnHeaders() { + Map headers1 = new HashMap<>(); + headers1.put("header", "value1"); + + Map headers2 = new HashMap<>(); + headers2.put("header", "value2"); + + HttpConnectProxiedSocketAddress 
address1 = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers1) + .build(); + + HttpConnectProxiedSocketAddress address2 = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers2) + .build(); + + assertNotEquals(address1.hashCode(), address2.hashCode()); + } + + @Test + public void multipleHeadersSupported() { + Map headers = new HashMap<>(); + headers.put("X-Header-1", "value1"); + headers.put("X-Header-2", "value2"); + headers.put("X-Header-3", "value3"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .build(); + + assertThat(address.getHeaders()).hasSize(3); + assertThat(address.getHeaders()).containsEntry("X-Header-1", "value1"); + assertThat(address.getHeaders()).containsEntry("X-Header-2", "value2"); + assertThat(address.getHeaders()).containsEntry("X-Header-3", "value3"); + } +} + diff --git a/api/src/test/java/io/grpc/LoadBalancerRegistryTest.java b/api/src/test/java/io/grpc/LoadBalancerRegistryTest.java index 5b348b7adab..690db6622e0 100644 --- a/api/src/test/java/io/grpc/LoadBalancerRegistryTest.java +++ b/api/src/test/java/io/grpc/LoadBalancerRegistryTest.java @@ -40,7 +40,7 @@ public void getClassesViaHardcoded_classesPresent() throws Exception { @Test public void stockProviders() { LoadBalancerRegistry defaultRegistry = LoadBalancerRegistry.getDefaultRegistry(); - assertThat(defaultRegistry.providers()).hasSize(3); + assertThat(defaultRegistry.providers()).hasSize(4); LoadBalancerProvider pickFirst = defaultRegistry.getProvider("pick_first"); assertThat(pickFirst).isInstanceOf(PickFirstLoadBalancerProvider.class); @@ -56,6 +56,12 @@ public void stockProviders() { assertThat(outlierDetection.getClass().getName()).isEqualTo( 
"io.grpc.util.OutlierDetectionLoadBalancerProvider"); assertThat(roundRobin.getPriority()).isEqualTo(5); + + LoadBalancerProvider randomSubsetting = defaultRegistry.getProvider( + "random_subsetting_experimental"); + assertThat(randomSubsetting.getClass().getName()).isEqualTo( + "io.grpc.util.RandomSubsettingLoadBalancerProvider"); + assertThat(randomSubsetting.getPriority()).isEqualTo(5); } @Test diff --git a/api/src/test/java/io/grpc/LoadBalancerTest.java b/api/src/test/java/io/grpc/LoadBalancerTest.java index 5e9e5cbe816..22fdc220081 100644 --- a/api/src/test/java/io/grpc/LoadBalancerTest.java +++ b/api/src/test/java/io/grpc/LoadBalancerTest.java @@ -64,6 +64,26 @@ public void pickResult_withSubchannelAndTracer() { assertThat(result.isDrop()).isFalse(); } + @Test + public void pickResult_withSubchannelReplacement() { + PickResult result = PickResult.withSubchannel(subchannel, tracerFactory) + .copyWithSubchannel(subchannel2); + assertThat(result.getSubchannel()).isSameInstanceAs(subchannel2); + assertThat(result.getStatus()).isSameInstanceAs(Status.OK); + assertThat(result.getStreamTracerFactory()).isSameInstanceAs(tracerFactory); + assertThat(result.isDrop()).isFalse(); + } + + @Test + public void pickResult_withStreamTracerFactory() { + PickResult result = PickResult.withSubchannel(subchannel) + .copyWithStreamTracerFactory(tracerFactory); + assertThat(result.getSubchannel()).isSameInstanceAs(subchannel); + assertThat(result.getStatus()).isSameInstanceAs(Status.OK); + assertThat(result.getStreamTracerFactory()).isSameInstanceAs(tracerFactory); + assertThat(result.isDrop()).isFalse(); + } + @Test public void pickResult_withNoResult() { PickResult result = PickResult.withNoResult(); diff --git a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java index 30de2477d77..2479e339791 100644 --- a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java +++ 
b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java @@ -20,17 +20,23 @@ import static org.junit.Assert.fail; import com.google.common.collect.ImmutableSet; +import io.grpc.FlagResetRule; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; /** Unit tests for {@link ManagedChannelRegistry}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class ManagedChannelRegistryTest { private String target = "testing123"; private ChannelCredentials creds = new ChannelCredentials() { @@ -40,6 +46,20 @@ public ChannelCredentials withoutBearerTokens() { } }; + @Rule public final FlagResetRule flagResetRule = new FlagResetRule(); + + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + + @Before + public void setUp() { + flagResetRule.setFlagForTest(FeatureFlags::setRfc3986UrisEnabled, enableRfc3986UrisParam); + } + @Test public void register_unavailableProviderThrows() { ManagedChannelRegistry reg = new ManagedChannelRegistry(); diff --git a/api/src/test/java/io/grpc/MetadataTest.java b/api/src/test/java/io/grpc/MetadataTest.java index 14ba8ca9b23..a858fff5e5a 100644 --- a/api/src/test/java/io/grpc/MetadataTest.java +++ b/api/src/test/java/io/grpc/MetadataTest.java @@ -16,6 +16,7 @@ package io.grpc; +import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.US_ASCII; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; 
@@ -24,6 +25,7 @@ import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -37,9 +39,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.Locale; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -49,9 +49,6 @@ @RunWith(JUnit4.class) public class MetadataTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); - private static final Metadata.BinaryMarshaller FISH_MARSHALLER = new Metadata.BinaryMarshaller() { @Override @@ -65,7 +62,7 @@ public Fish parseBytes(byte[] serialized) { } }; - private static class FishStreamMarsaller implements Metadata.BinaryStreamMarshaller { + private static class FishStreamMarshaller implements Metadata.BinaryStreamMarshaller { @Override public InputStream toStream(Fish fish) { return new ByteArrayInputStream(FISH_MARSHALLER.toBytes(fish)); @@ -82,7 +79,7 @@ public Fish parseStream(InputStream stream) { } private static final Metadata.BinaryStreamMarshaller FISH_STREAM_MARSHALLER = - new FishStreamMarsaller(); + new FishStreamMarshaller(); /** A pattern commonly used to avoid unnecessary serialization of immutable objects. 
*/ private static final class FakeFishStream extends InputStream { @@ -121,10 +118,9 @@ public Fish parseStream(InputStream stream) { @Test public void noPseudoHeaders() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid character"); - - Metadata.Key.of(":test-bin", FISH_MARSHALLER); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Metadata.Key.of(":test-bin", FISH_MARSHALLER)); + assertThat(e).hasMessageThat().isEqualTo("Invalid character ':' in key name ':test-bin'"); } @Test @@ -186,8 +182,7 @@ public void testGetAllNoRemove() { Iterator i = metadata.getAll(KEY).iterator(); assertEquals(lance, i.next()); - thrown.expect(UnsupportedOperationException.class); - i.remove(); + assertThrows(UnsupportedOperationException.class, i::remove); } @Test @@ -271,17 +266,15 @@ public void mergeExpands() { @Test public void shortBinaryKeyName() { - thrown.expect(IllegalArgumentException.class); - - Metadata.Key.of("-bin", FISH_MARSHALLER); + assertThrows(IllegalArgumentException.class, () -> Metadata.Key.of("-bin", FISH_MARSHALLER)); } @Test public void invalidSuffixBinaryKeyName() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Binary header is named"); - - Metadata.Key.of("nonbinary", FISH_MARSHALLER); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Metadata.Key.of("nonbinary", FISH_MARSHALLER)); + assertThat(e).hasMessageThat() + .isEqualTo("Binary header is named nonbinary. It must end with -bin"); } @Test @@ -415,7 +408,7 @@ public void streamedValueDifferentMarshaller() { h.put(KEY_STREAMED, salmon); // Get using a different marshaller instance. 
- Fish fish = h.get(copyKey(KEY_STREAMED, new FishStreamMarsaller())); + Fish fish = h.get(copyKey(KEY_STREAMED, new FishStreamMarshaller())); assertEquals(salmon, fish); } diff --git a/api/src/test/java/io/grpc/MethodDescriptorTest.java b/api/src/test/java/io/grpc/MethodDescriptorTest.java index 9431190984b..e068e0c1108 100644 --- a/api/src/test/java/io/grpc/MethodDescriptorTest.java +++ b/api/src/test/java/io/grpc/MethodDescriptorTest.java @@ -26,9 +26,7 @@ import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.MethodType; import io.grpc.testing.TestMethodDescriptors; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,10 +35,6 @@ */ @RunWith(JUnit4.class) public class MethodDescriptorTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void createMethodDescriptor() { MethodDescriptor descriptor = MethodDescriptor.newBuilder() diff --git a/api/src/test/java/io/grpc/NameResolverRegistryTest.java b/api/src/test/java/io/grpc/NameResolverRegistryTest.java index 2fd23e3a974..76976c3b59b 100644 --- a/api/src/test/java/io/grpc/NameResolverRegistryTest.java +++ b/api/src/test/java/io/grpc/NameResolverRegistryTest.java @@ -33,7 +33,8 @@ /** Unit tests for {@link NameResolverRegistry}. 
*/ @RunWith(JUnit4.class) public class NameResolverRegistryTest { - private final URI uri = URI.create("dns:///localhost"); + private final URI javaNetUri = URI.create("dns:///localhost"); + private final Uri ioGrpcUri = Uri.create("dns:///localhost"); private final NameResolver.Args args = NameResolver.Args.newBuilder() .setDefaultPort(8080) .setProxyDetector(mock(ProxyDetector.class)) @@ -96,43 +97,80 @@ public void getDefaultScheme_noProvider() { } @Test - public void newNameResolver_providerReturnsNull() { + public void newNameResolver_providerReturnsNull_ioGrpcUri() { NameResolverRegistry registry = new NameResolverRegistry(); registry.register( - new BaseProvider(true, 5, "noScheme") { + new BaseProvider(true, 5, ioGrpcUri.getScheme()) { @Override - public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) { - assertThat(passedUri).isSameInstanceAs(uri); + public NameResolver newNameResolver(Uri passedUri, NameResolver.Args passedArgs) { + assertThat(passedUri).isSameInstanceAs(ioGrpcUri); assertThat(passedArgs).isSameInstanceAs(args); return null; } }); - assertThat(registry.asFactory().newNameResolver(uri, args)).isNull(); - assertThat(registry.asFactory().getDefaultScheme()).isEqualTo("noScheme"); + assertThat(registry.asFactory().newNameResolver(ioGrpcUri, args)).isNull(); + assertThat(registry.asFactory().getDefaultScheme()).isEqualTo(ioGrpcUri.getScheme()); } @Test - public void newNameResolver_providerReturnsNonNull() { + public void newNameResolver_providerReturnsNull_javaNetUri() { NameResolverRegistry registry = new NameResolverRegistry(); - registry.register(new BaseProvider(true, 5, uri.getScheme()) { - @Override - public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) { - return null; - } - }); - final NameResolver nr = new NameResolver() { - @Override public String getServiceAuthority() { - throw new UnsupportedOperationException(); - } + registry.register( + new BaseProvider(true, 5, 
javaNetUri.getScheme()) { + @Override + public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) { + assertThat(passedUri).isSameInstanceAs(javaNetUri); + assertThat(passedArgs).isSameInstanceAs(args); + return null; + } + }); + assertThat(registry.asFactory().newNameResolver(javaNetUri, args)).isNull(); + assertThat(registry.asFactory().getDefaultScheme()).isEqualTo(javaNetUri.getScheme()); + } - @Override public void start(Listener2 listener) { - throw new UnsupportedOperationException(); - } + @Test + public void newNameResolver_providerReturnsNonNull_ioGrpcUri() { + NameResolverRegistry registry = new NameResolverRegistry(); + Uri uri = ioGrpcUri; + registry.register( + new BaseProvider(true, 5, uri.getScheme()) { + @Override + public NameResolver newNameResolver(Uri passedUri, NameResolver.Args passedArgs) { + return null; + } + }); + final NameResolver nr = new DummyNameResolver(); + registry.register( + new BaseProvider(true, 4, uri.getScheme()) { + @Override + public NameResolver newNameResolver(Uri passedUri, NameResolver.Args passedArgs) { + return nr; + } + }); + registry.register( + new BaseProvider(true, 3, uri.getScheme()) { + @Override + public NameResolver newNameResolver(Uri passedUri, NameResolver.Args passedArgs) { + fail("Should not be called"); + throw new AssertionError(); + } + }); + assertThat(registry.asFactory().newNameResolver(uri, args)).isNull(); + assertThat(registry.asFactory().getDefaultScheme()).isEqualTo(uri.getScheme()); + } - @Override public void shutdown() { - throw new UnsupportedOperationException(); - } - }; + @Test + public void newNameResolver_providerReturnsNonNull_javaNetUri() { + NameResolverRegistry registry = new NameResolverRegistry(); + URI uri = javaNetUri; + registry.register( + new BaseProvider(true, 5, uri.getScheme()) { + @Override + public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) { + return null; + } + }); + final NameResolver nr = new 
DummyNameResolver(); registry.register( new BaseProvider(true, 4, uri.getScheme()) { @Override @@ -153,27 +191,45 @@ public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) } @Test - public void newNameResolver_multipleScheme() { + public void newNameResolver_multipleScheme_ioGrpcUri() { NameResolverRegistry registry = new NameResolverRegistry(); - registry.register(new BaseProvider(true, 5, uri.getScheme()) { - @Override - public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) { - return null; - } - }); - final NameResolver nr = new NameResolver() { - @Override public String getServiceAuthority() { - throw new UnsupportedOperationException(); - } + Uri uri = ioGrpcUri; + registry.register( + new BaseProvider(true, 5, uri.getScheme()) { + @Override + public NameResolver newNameResolver(Uri passedUri, NameResolver.Args passedArgs) { + return null; + } + }); + final NameResolver nr = new DummyNameResolver(); + registry.register( + new BaseProvider(true, 4, "other") { + @Override + public NameResolver newNameResolver(Uri passedUri, NameResolver.Args passedArgs) { + return nr; + } + }); - @Override public void start(Listener2 listener) { - throw new UnsupportedOperationException(); - } + assertThat(registry.asFactory().newNameResolver(uri, args)).isNull(); + assertThat(registry.asFactory().newNameResolver(Uri.create("other:///0.0.0.0:80"), args)) + .isSameInstanceAs(nr); + assertThat(registry.asFactory().newNameResolver(Uri.create("OTHER:///0.0.0.0:80"), args)) + .isSameInstanceAs(nr); + assertThat(registry.asFactory().getDefaultScheme()).isEqualTo("dns"); + } - @Override public void shutdown() { - throw new UnsupportedOperationException(); - } - }; + @Test + public void newNameResolver_multipleScheme_javaNetUri() { + NameResolverRegistry registry = new NameResolverRegistry(); + URI uri = javaNetUri; + registry.register( + new BaseProvider(true, 5, uri.getScheme()) { + @Override + public NameResolver 
newNameResolver(URI passedUri, NameResolver.Args passedArgs) { + return null; + } + }); + final NameResolver nr = new DummyNameResolver(); registry.register( new BaseProvider(true, 4, "other") { @Override @@ -186,16 +242,17 @@ public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) assertThat(registry.asFactory().newNameResolver(URI.create("/0.0.0.0:80"), args)).isNull(); assertThat(registry.asFactory().newNameResolver(URI.create("///0.0.0.0:80"), args)).isNull(); assertThat(registry.asFactory().newNameResolver(URI.create("other:///0.0.0.0:80"), args)) - .isSameInstanceAs(nr); + .isSameInstanceAs(nr); assertThat(registry.asFactory().newNameResolver(URI.create("OTHER:///0.0.0.0:80"), args)) - .isSameInstanceAs(nr); + .isSameInstanceAs(nr); assertThat(registry.asFactory().getDefaultScheme()).isEqualTo("dns"); } @Test public void newNameResolver_noProvider() { NameResolver.Factory factory = new NameResolverRegistry().asFactory(); - assertThat(factory.newNameResolver(uri, args)).isNull(); + assertThat(factory.newNameResolver(javaNetUri, args)).isNull(); + assertThat(factory.newNameResolver(ioGrpcUri, args)).isNull(); assertThat(factory.getDefaultScheme()).isEqualTo("unknown"); } @@ -261,9 +318,31 @@ public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { throw new UnsupportedOperationException(); } + @Override + public NameResolver newNameResolver(Uri targetUri, NameResolver.Args args) { + throw new UnsupportedOperationException(); + } + @Override public String getDefaultScheme() { return scheme == null ? 
"scheme" + getClass().getSimpleName() : scheme; } } + + private static class DummyNameResolver extends NameResolver { + @Override + public String getServiceAuthority() { + throw new UnsupportedOperationException(); + } + + @Override + public void start(Listener2 listener) { + throw new UnsupportedOperationException(); + } + + @Override + public void shutdown() { + throw new UnsupportedOperationException(); + } + } } diff --git a/api/src/test/java/io/grpc/NameResolverTest.java b/api/src/test/java/io/grpc/NameResolverTest.java index f825de354af..82abe5c7505 100644 --- a/api/src/test/java/io/grpc/NameResolverTest.java +++ b/api/src/test/java/io/grpc/NameResolverTest.java @@ -17,20 +17,47 @@ package io.grpc; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import com.google.common.base.Objects; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.NameResolver.Listener2; +import io.grpc.NameResolver.ResolutionResult; import io.grpc.NameResolver.ServiceConfigParser; import java.lang.Thread.UncaughtExceptionHandler; +import java.net.SocketAddress; +import java.util.Collections; +import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; /** Unit tests for the inner classes in {@link NameResolver}. 
*/ @RunWith(JUnit4.class) public class NameResolverTest { + private static final List ADDRESSES = + Collections.singletonList( + new EquivalentAddressGroup(new FakeSocketAddress("fake-address-1"), Attributes.EMPTY)); + private static final Attributes.Key YOLO_ATTR_KEY = Attributes.Key.create("yolo"); + private static Attributes ATTRIBUTES = + Attributes.newBuilder().set(YOLO_ATTR_KEY, "To be, or not to be?").build(); + private static final NameResolver.Args.Key FOO_ARG_KEY = + NameResolver.Args.Key.create("foo"); + private static final NameResolver.Args.Key BAR_ARG_KEY = + NameResolver.Args.Key.create("bar"); + private static ConfigOrError CONFIG = ConfigOrError.fromConfig("foo"); + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); private final int defaultPort = 293; private final ProxyDetector proxyDetector = mock(ProxyDetector.class); private final SynchronizationContext syncContext = @@ -41,6 +68,9 @@ public class NameResolverTest { private final ChannelLogger channelLogger = mock(ChannelLogger.class); private final Executor executor = Executors.newSingleThreadExecutor(); private final String overrideAuthority = "grpc.io"; + private final MetricRecorder metricRecorder = new MetricRecorder() {}; + private final int customArgValue = 42; + @Mock NameResolver.Listener mockListener; @Test public void args() { @@ -53,6 +83,9 @@ public void args() { assertThat(args.getChannelLogger()).isSameInstanceAs(channelLogger); assertThat(args.getOffloadExecutor()).isSameInstanceAs(executor); assertThat(args.getOverrideAuthority()).isSameInstanceAs(overrideAuthority); + assertThat(args.getMetricRecorder()).isSameInstanceAs(metricRecorder); + assertThat(args.getArg(FOO_ARG_KEY)).isEqualTo(customArgValue); + assertThat(args.getArg(BAR_ARG_KEY)).isNull(); NameResolver.Args args2 = args.toBuilder().build(); assertThat(args2.getDefaultPort()).isEqualTo(defaultPort); @@ -63,6 +96,9 @@ public void args() { 
assertThat(args2.getChannelLogger()).isSameInstanceAs(channelLogger); assertThat(args2.getOffloadExecutor()).isSameInstanceAs(executor); assertThat(args2.getOverrideAuthority()).isSameInstanceAs(overrideAuthority); + assertThat(args.getMetricRecorder()).isSameInstanceAs(metricRecorder); + assertThat(args.getArg(FOO_ARG_KEY)).isEqualTo(customArgValue); + assertThat(args.getArg(BAR_ARG_KEY)).isNull(); assertThat(args2).isNotSameInstanceAs(args); assertThat(args2).isNotEqualTo(args); @@ -78,6 +114,144 @@ private NameResolver.Args createArgs() { .setChannelLogger(channelLogger) .setOffloadExecutor(executor) .setOverrideAuthority(overrideAuthority) + .setMetricRecorder(metricRecorder) + .setArg(FOO_ARG_KEY, customArgValue) + .build(); + } + + @Test + @SuppressWarnings("deprecation") + public void startOnOldListener_wrapperListener2UsedToStart() { + final Listener2[] listener2 = new Listener2[1]; + NameResolver nameResolver = new NameResolver() { + @Override + public String getServiceAuthority() { + return null; + } + + @Override + public void shutdown() {} + + @Override + public void start(Listener2 listener2Arg) { + listener2[0] = listener2Arg; + } + }; + nameResolver.start(mockListener); + + listener2[0].onResult(ResolutionResult.newBuilder().setAddresses(ADDRESSES) + .setAttributes(ATTRIBUTES).build()); + verify(mockListener).onAddresses(eq(ADDRESSES), eq(ATTRIBUTES)); + listener2[0].onError(Status.CANCELLED); + verify(mockListener).onError(Status.CANCELLED); + } + + @Test + @SuppressWarnings({"deprecation", "InlineMeInliner"}) + public void listener2AddressesToListener2ResolutionResultConversion() { + final ResolutionResult[] resolutionResult = new ResolutionResult[1]; + NameResolver.Listener2 listener2 = new Listener2() { + @Override + public void onResult(ResolutionResult resolutionResultArg) { + resolutionResult[0] = resolutionResultArg; + } + + @Override + public void onError(Status error) {} + }; + + listener2.onAddresses(ADDRESSES, ATTRIBUTES); + + 
assertThat(resolutionResult[0].getAddressesOrError().getValue()).isEqualTo(ADDRESSES); + assertThat(resolutionResult[0].getAttributes()).isEqualTo(ATTRIBUTES); + } + + @Test + public void resolutionResult_toString_addressesAttributesAndConfig() { + ResolutionResult resolutionResult = ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(ADDRESSES)) + .setAttributes(ATTRIBUTES) + .setServiceConfig(CONFIG) + .build(); + + assertThat(resolutionResult.toString()).isEqualTo( + "ResolutionResult{addressesOrError=StatusOr{value=" + + "[[[FakeSocketAddress-fake-address-1]/{}]]}, attributes={yolo=To be, or not to be?}, " + + "serviceConfigOrError=ConfigOrError{config=foo}}"); + } + + @Test + public void resolutionResult_hashCode() { + ResolutionResult resolutionResult = ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(ADDRESSES)) + .setAttributes(ATTRIBUTES) + .setServiceConfig(CONFIG) .build(); + + assertThat(resolutionResult.hashCode()).isEqualTo( + Objects.hashCode(StatusOr.fromValue(ADDRESSES), ATTRIBUTES, CONFIG)); + } + + @Test + public void startOnOldListener_resolverReportsError() { + final boolean[] onErrorCalled = new boolean[1]; + final Status[] receivedError = new Status[1]; + + NameResolver resolver = new NameResolver() { + @Override + public String getServiceAuthority() { + return "example.com"; + } + + @Override + public void shutdown() { + } + + @Override + public void start(Listener2 listener2) { + ResolutionResult errorResult = ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromStatus( + Status.UNAVAILABLE + .withDescription("DNS resolution failed with UNAVAILABLE"))) + .build(); + + listener2.onResult(errorResult); + } + }; + + NameResolver.Listener listener = new NameResolver.Listener() { + @Override + public void onAddresses( + List servers, + Attributes attributes) { + throw new AssertionError("Called onAddresses on error"); + } + + @Override + public void onError(Status error) { + 
onErrorCalled[0] = true; + receivedError[0] = error; + } + }; + + resolver.start(listener); + + assertThat(onErrorCalled[0]).isTrue(); + assertThat(receivedError[0].getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(receivedError[0].getDescription()).isEqualTo( + "DNS resolution failed with UNAVAILABLE"); + } + + private static class FakeSocketAddress extends SocketAddress { + final String name; + + FakeSocketAddress(String name) { + this.name = name; + } + + @Override + public String toString() { + return "FakeSocketAddress-" + name; + } } } diff --git a/api/src/test/java/io/grpc/ServerInterceptorsTest.java b/api/src/test/java/io/grpc/ServerInterceptorsTest.java index abfb3540fe4..b84b3838afa 100644 --- a/api/src/test/java/io/grpc/ServerInterceptorsTest.java +++ b/api/src/test/java/io/grpc/ServerInterceptorsTest.java @@ -19,6 +19,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.same; @@ -40,7 +41,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentMatchers; @@ -55,10 +55,6 @@ public class ServerInterceptorsTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Mock private Marshaller requestMarshaller; @@ -111,21 +107,21 @@ public void makeSureExpectedMocksUnused() { public void npeForNullServiceDefinition() { ServerServiceDefinition serviceDef = null; List interceptors = Arrays.asList(); - thrown.expect(NullPointerException.class); 
- ServerInterceptors.intercept(serviceDef, interceptors); + assertThrows(NullPointerException.class, + () -> ServerInterceptors.intercept(serviceDef, interceptors)); } @Test public void npeForNullInterceptorList() { - thrown.expect(NullPointerException.class); - ServerInterceptors.intercept(serviceDefinition, (List) null); + assertThrows(NullPointerException.class, + () -> ServerInterceptors.intercept(serviceDefinition, (List) null)); } @Test public void npeForNullInterceptor() { List interceptors = Arrays.asList((ServerInterceptor) null); - thrown.expect(NullPointerException.class); - ServerInterceptors.intercept(serviceDefinition, interceptors); + assertThrows(NullPointerException.class, + () -> ServerInterceptors.intercept(serviceDefinition, interceptors)); } @Test diff --git a/api/src/test/java/io/grpc/ServerServiceDefinitionTest.java b/api/src/test/java/io/grpc/ServerServiceDefinitionTest.java index 6a84d640d78..9e43302e210 100644 --- a/api/src/test/java/io/grpc/ServerServiceDefinitionTest.java +++ b/api/src/test/java/io/grpc/ServerServiceDefinitionTest.java @@ -18,14 +18,13 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -52,9 +51,6 @@ public class ServerServiceDefinitionTest { = ServerMethodDefinition.create(method1, methodHandler1); private ServerMethodDefinition methodDef2 = ServerMethodDefinition.create(method2, methodHandler2); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public ExpectedException thrown = ExpectedException.none(); @Test public void noMethods() { @@ -91,9 +87,7 @@ public void addMethod_duplicateName() { ServiceDescriptor sd = 
new ServiceDescriptor(serviceName, method1); ServerServiceDefinition.Builder ssd = ServerServiceDefinition.builder(sd) .addMethod(method1, methodHandler1); - thrown.expect(IllegalStateException.class); - ssd.addMethod(diffMethod1, methodHandler2) - .build(); + assertThrows(IllegalStateException.class, () -> ssd.addMethod(diffMethod1, methodHandler2)); } @Test @@ -101,8 +95,7 @@ public void buildMisaligned_extraMethod() { ServiceDescriptor sd = new ServiceDescriptor(serviceName); ServerServiceDefinition.Builder ssd = ServerServiceDefinition.builder(sd) .addMethod(methodDef1); - thrown.expect(IllegalStateException.class); - ssd.build(); + assertThrows(IllegalStateException.class, ssd::build); } @Test @@ -110,16 +103,14 @@ public void buildMisaligned_diffMethodInstance() { ServiceDescriptor sd = new ServiceDescriptor(serviceName, method1); ServerServiceDefinition.Builder ssd = ServerServiceDefinition.builder(sd) .addMethod(diffMethod1, methodHandler1); - thrown.expect(IllegalStateException.class); - ssd.build(); + assertThrows(IllegalStateException.class, ssd::build); } @Test public void buildMisaligned_missingMethod() { ServiceDescriptor sd = new ServiceDescriptor(serviceName, method1); ServerServiceDefinition.Builder ssd = ServerServiceDefinition.builder(sd); - thrown.expect(IllegalStateException.class); - ssd.build(); + assertThrows(IllegalStateException.class, ssd::build); } @Test diff --git a/api/src/test/java/io/grpc/ServiceDescriptorTest.java b/api/src/test/java/io/grpc/ServiceDescriptorTest.java index a05858680d5..89bdead3632 100644 --- a/api/src/test/java/io/grpc/ServiceDescriptorTest.java +++ b/api/src/test/java/io/grpc/ServiceDescriptorTest.java @@ -16,17 +16,18 @@ package io.grpc; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import com.google.common.truth.StringSubject; import io.grpc.MethodDescriptor.MethodType; import 
io.grpc.testing.TestMethodDescriptors; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,32 +37,27 @@ @RunWith(JUnit4.class) public class ServiceDescriptorTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void failsOnNullName() { - thrown.expect(NullPointerException.class); - thrown.expectMessage("name"); - - new ServiceDescriptor(null, Collections.>emptyList()); + List> methods = Collections.emptyList(); + NullPointerException e = assertThrows(NullPointerException.class, + () -> new ServiceDescriptor(null, methods)); + assertThat(e).hasMessageThat().isEqualTo("name"); } @Test public void failsOnNullMethods() { - thrown.expect(NullPointerException.class); - thrown.expectMessage("methods"); - - new ServiceDescriptor("name", (Collection>) null); + NullPointerException e = assertThrows(NullPointerException.class, + () -> new ServiceDescriptor("name", (Collection>) null)); + assertThat(e).hasMessageThat().isEqualTo("methods"); } @Test public void failsOnNullMethod() { - thrown.expect(NullPointerException.class); - thrown.expectMessage("method"); - - new ServiceDescriptor("name", Collections.>singletonList(null)); + List> methods = Collections.singletonList(null); + NullPointerException e = assertThrows(NullPointerException.class, + () -> new ServiceDescriptor("name", methods)); + assertThat(e).hasMessageThat().isEqualTo("method"); } @Test @@ -69,15 +65,17 @@ public void failsOnNonMatchingNames() { List> descriptors = Collections.>singletonList( MethodDescriptor.newBuilder() .setType(MethodType.UNARY) - .setFullMethodName(MethodDescriptor.generateFullMethodName("wrongservice", "method")) + 
.setFullMethodName(MethodDescriptor.generateFullMethodName("wrongService", "method")) .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) .build()); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("service names"); - - new ServiceDescriptor("name", descriptors); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> new ServiceDescriptor("fooService", descriptors)); + StringSubject error = assertThat(e).hasMessageThat(); + error.contains("service names"); + error.contains("fooService"); + error.contains("wrongService"); } @Test @@ -96,10 +94,9 @@ public void failsOnNonDuplicateNames() { .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) .build()); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("duplicate"); - - new ServiceDescriptor("name", descriptors); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> new ServiceDescriptor("name", descriptors)); + assertThat(e).hasMessageThat().isEqualTo("duplicate name name/method"); } @Test diff --git a/api/src/test/java/io/grpc/ServiceProvidersTest.java b/api/src/test/java/io/grpc/ServiceProvidersTest.java index 7d4388a5bb9..f971ed42646 100644 --- a/api/src/test/java/io/grpc/ServiceProvidersTest.java +++ b/api/src/test/java/io/grpc/ServiceProvidersTest.java @@ -16,6 +16,7 @@ package io.grpc; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -23,12 +24,15 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; import io.grpc.InternalServiceProviders.PriorityAccessor; import java.util.Collections; import java.util.Iterator; import java.util.List; import 
java.util.ServiceConfigurationError; +import java.util.ServiceLoader; +import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,7 +40,6 @@ /** Unit tests for {@link ServiceProviders}. */ @RunWith(JUnit4.class) public class ServiceProvidersTest { - private static final List> NO_HARDCODED = Collections.emptyList(); private static final PriorityAccessor ACCESSOR = new PriorityAccessor() { @Override @@ -51,6 +54,19 @@ public int getPriority(ServiceProvidersTestAbstractProvider provider) { }; private final String serviceFile = "META-INF/services/io.grpc.ServiceProvidersTestAbstractProvider"; + private boolean failingHardCodedAccessed; + private final Supplier>> failingHardCoded = new Supplier>>() { + @Override + public Iterable> get() { + failingHardCodedAccessed = true; + throw new AssertionError(); + } + }; + + @After + public void tearDown() { + assertThat(failingHardCodedAccessed).isFalse(); + } @Test public void contextClassLoaderProvider() { @@ -69,8 +85,8 @@ public void contextClassLoaderProvider() { Thread.currentThread().setContextClassLoader(rcll); assertEquals( Available7Provider.class, - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR).getClass()); + load(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR) + .getClass()); } finally { Thread.currentThread().setContextClassLoader(ccl); } @@ -85,8 +101,7 @@ public void noProvider() { serviceFile, "io/grpc/ServiceProvidersTestAbstractProvider-doesNotExist.txt"); Thread.currentThread().setContextClassLoader(cl); - assertNull(ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR)); + assertNull(load(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR)); } finally { Thread.currentThread().setContextClassLoader(ccl); } @@ -98,11 +113,11 @@ public void multipleProvider() throws Exception { 
"io/grpc/ServiceProvidersTestAbstractProvider-multipleProvider.txt"); assertSame( Available7Provider.class, - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR).getClass()); + load(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR) + .getClass()); - List providers = ServiceProviders.loadAll( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + List providers = loadAll( + ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); assertEquals(3, providers.size()); assertEquals(Available7Provider.class, providers.get(0).getClass()); assertEquals(Available5Provider.class, providers.get(1).getClass()); @@ -116,8 +131,8 @@ public void unavailableProvider() { "io/grpc/ServiceProvidersTestAbstractProvider-unavailableProvider.txt"); assertEquals( Available7Provider.class, - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR).getClass()); + load(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR) + .getClass()); } @Test @@ -125,8 +140,7 @@ public void unknownClassProvider() { ClassLoader cl = new ReplacingClassLoader(getClass().getClassLoader(), serviceFile, "io/grpc/ServiceProvidersTestAbstractProvider-unknownClassProvider.txt"); try { - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + loadAll(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); fail("Exception expected"); } catch (ServiceConfigurationError e) { // noop @@ -140,8 +154,7 @@ public void exceptionSurfacedToCaller_failAtInit() { try { // Even though there is a working provider, if any providers fail then we should fail // completely to avoid returning something unexpected. 
- ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + loadAll(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); fail("Expected exception"); } catch (ServiceConfigurationError expected) { // noop @@ -154,8 +167,7 @@ public void exceptionSurfacedToCaller_failAtPriority() { "io/grpc/ServiceProvidersTestAbstractProvider-failAtPriorityProvider.txt"); try { // The exception should be surfaced to the caller - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + loadAll(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); fail("Expected exception"); } catch (FailAtPriorityProvider.PriorityException expected) { // noop @@ -168,8 +180,7 @@ public void exceptionSurfacedToCaller_failAtAvailable() { "io/grpc/ServiceProvidersTestAbstractProvider-failAtAvailableProvider.txt"); try { // The exception should be surfaced to the caller - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + loadAll(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); fail("Expected exception"); } catch (FailAtAvailableProvider.AvailableException expected) { // noop @@ -244,6 +255,30 @@ class RandomClass {} assertFalse(candidates.iterator().hasNext()); } + private static T load( + Class klass, + Supplier>> hardCoded, + ClassLoader cl, + PriorityAccessor priorityAccessor) { + List candidates = loadAll(klass, hardCoded, cl, priorityAccessor); + if (candidates.isEmpty()) { + return null; + } + return candidates.get(0); + } + + private static List loadAll( + Class klass, + Supplier>> hardCoded, + ClassLoader classLoader, + PriorityAccessor priorityAccessor) { + return ServiceProviders.loadAll( + klass, + ServiceLoader.load(klass, classLoader).iterator(), + hardCoded, + priorityAccessor); + } + private static class BaseProvider extends ServiceProvidersTestAbstractProvider { private final 
boolean isAvailable; private final int priority; diff --git a/api/src/test/java/io/grpc/StatusExceptionTest.java b/api/src/test/java/io/grpc/StatusExceptionTest.java index dd0d12dccda..410cfb2289a 100644 --- a/api/src/test/java/io/grpc/StatusExceptionTest.java +++ b/api/src/test/java/io/grpc/StatusExceptionTest.java @@ -28,14 +28,6 @@ @RunWith(JUnit4.class) public class StatusExceptionTest { - @Test - public void internalCtorRemovesStack() { - StackTraceElement[] trace = - new StatusException(Status.CANCELLED, null, false) {}.getStackTrace(); - - assertThat(trace).isEmpty(); - } - @Test public void normalCtorKeepsStack() { StackTraceElement[] trace = diff --git a/api/src/test/java/io/grpc/StatusOrTest.java b/api/src/test/java/io/grpc/StatusOrTest.java new file mode 100644 index 00000000000..f63a314a2bb --- /dev/null +++ b/api/src/test/java/io/grpc/StatusOrTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2015 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.fail; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link StatusOr}. 
**/ +@RunWith(JUnit4.class) +public class StatusOrTest { + + @Test + public void getValue_throwsIfNoValuePresent() { + try { + StatusOr.fromStatus(Status.ABORTED).getValue(); + + fail("Expected exception."); + } catch (IllegalStateException expected) { } + } + + @Test + @SuppressWarnings("TruthIncompatibleType") + public void equals_differentValueTypes() { + assertThat(StatusOr.fromValue(1)).isNotEqualTo(StatusOr.fromValue("1")); + } + + @Test + public void equals_differentValues() { + assertThat(StatusOr.fromValue(1)).isNotEqualTo(StatusOr.fromValue(2)); + } + + @Test + public void equals_sameValues() { + assertThat(StatusOr.fromValue(1)).isEqualTo(StatusOr.fromValue(1)); + } + + @Test + public void equals_differentStatuses() { + assertThat(StatusOr.fromStatus(Status.ABORTED)).isNotEqualTo( + StatusOr.fromStatus(Status.CANCELLED)); + } + + @Test + public void equals_sameStatuses() { + assertThat(StatusOr.fromStatus(Status.ABORTED)).isEqualTo(StatusOr.fromStatus(Status.ABORTED)); + } + + @Test + public void toString_value() { + assertThat(StatusOr.fromValue(1).toString()).isEqualTo("StatusOr{value=1}"); + } + + @Test + public void toString_nullValue() { + assertThat(StatusOr.fromValue(null).toString()).isEqualTo("StatusOr{value=null}"); + } + + @Test + public void toString_errorStatus() { + assertThat(StatusOr.fromStatus(Status.ABORTED).toString()).isEqualTo( + "StatusOr{error=Status{code=ABORTED, description=null, cause=null}}"); + } +} \ No newline at end of file diff --git a/api/src/test/java/io/grpc/StatusRuntimeExceptionTest.java b/api/src/test/java/io/grpc/StatusRuntimeExceptionTest.java index ab20c111254..d965ed86253 100644 --- a/api/src/test/java/io/grpc/StatusRuntimeExceptionTest.java +++ b/api/src/test/java/io/grpc/StatusRuntimeExceptionTest.java @@ -31,7 +31,7 @@ public class StatusRuntimeExceptionTest { @Test public void internalCtorRemovesStack() { StackTraceElement[] trace = - new StatusRuntimeException(Status.CANCELLED, null, false) 
{}.getStackTrace(); + new InternalStatusRuntimeException(Status.CANCELLED, null) {}.getStackTrace(); assertThat(trace).isEmpty(); } diff --git a/api/src/test/java/io/grpc/SynchronizationContextTest.java b/api/src/test/java/io/grpc/SynchronizationContextTest.java index 3d5e7fa42b9..668f5ae4d6d 100644 --- a/api/src/test/java/io/grpc/SynchronizationContextTest.java +++ b/api/src/test/java/io/grpc/SynchronizationContextTest.java @@ -27,6 +27,7 @@ import com.google.common.util.concurrent.testing.TestingExecutors; import io.grpc.SynchronizationContext.ScheduledHandle; +import java.time.Duration; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; @@ -72,7 +73,7 @@ public void uncaughtException(Thread t, Throwable e) { @Mock private Runnable task3; - + @After public void tearDown() { assertThat(uncaughtErrors).isEmpty(); } @@ -246,6 +247,43 @@ public void schedule() { verify(task1).run(); } + @Test + @IgnoreJRERequirement + public void scheduleDuration() { + MockScheduledExecutorService executorService = new MockScheduledExecutorService(); + ScheduledHandle handle = + syncContext.schedule(task1, Duration.ofSeconds(10), executorService); + + assertThat(executorService.delay) + .isEqualTo(executorService.unit.convert(10, TimeUnit.SECONDS)); + assertThat(handle.isPending()).isTrue(); + verify(task1, never()).run(); + + executorService.command.run(); + + assertThat(handle.isPending()).isFalse(); + verify(task1).run(); + } + + @Test + @IgnoreJRERequirement + public void scheduleWithFixedDelayDuration() { + MockScheduledExecutorService executorService = new MockScheduledExecutorService(); + ScheduledHandle handle = + syncContext.scheduleWithFixedDelay(task1, Duration.ofSeconds(10), + Duration.ofSeconds(10), executorService); + + assertThat(executorService.delay) + .isEqualTo(executorService.unit.convert(10, TimeUnit.SECONDS)); + assertThat(handle.isPending()).isTrue(); + verify(task1, 
never()).run(); + + executorService.command.run(); + + assertThat(handle.isPending()).isFalse(); + verify(task1).run(); + } + @Test public void scheduleDueImmediately() { MockScheduledExecutorService executorService = new MockScheduledExecutorService(); @@ -357,5 +395,13 @@ static class MockScheduledExecutorService extends ForwardingScheduledExecutorSer this.unit = unit; return future = super.schedule(command, delay, unit); } + + @Override public ScheduledFuture scheduleWithFixedDelay(Runnable command, long intialDelay, + long delay, TimeUnit unit) { + this.command = command; + this.delay = delay; + this.unit = unit; + return future = super.scheduleWithFixedDelay(command, intialDelay, delay, unit); + } } } diff --git a/api/src/test/java/io/grpc/TimeUtilsTest.java b/api/src/test/java/io/grpc/TimeUtilsTest.java new file mode 100644 index 00000000000..728b8512cd7 --- /dev/null +++ b/api/src/test/java/io/grpc/TimeUtilsTest.java @@ -0,0 +1,60 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static org.junit.Assert.assertEquals; + +import java.time.Duration; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link TimeUtils}. 
*/ +@RunWith(JUnit4.class) +@IgnoreJRERequirement +public class TimeUtilsTest { + + @Test + public void testConvertNormalDuration() { + Duration duration = Duration.ofSeconds(10); + long expected = 10 * 1_000_000_000L; + + assertEquals(expected, TimeUtils.convertToNanos(duration)); + } + + @Test + public void testConvertNegativeDuration() { + Duration duration = Duration.ofSeconds(-3); + long expected = -3 * 1_000_000_000L; + + assertEquals(expected, TimeUtils.convertToNanos(duration)); + } + + @Test + public void testConvertTooLargeDuration() { + Duration duration = Duration.ofSeconds(Long.MAX_VALUE / 1_000_000_000L + 1); + + assertEquals(Long.MAX_VALUE, TimeUtils.convertToNanos(duration)); + } + + @Test + public void testConvertTooLargeNegativeDuration() { + Duration duration = Duration.ofSeconds(Long.MIN_VALUE / 1_000_000_000L - 1); + + assertEquals(Long.MIN_VALUE, TimeUtils.convertToNanos(duration)); + } +} diff --git a/api/src/test/java/io/grpc/UriTest.java b/api/src/test/java/io/grpc/UriTest.java new file mode 100644 index 00000000000..a1bd550696f --- /dev/null +++ b/api/src/test/java/io/grpc/UriTest.java @@ -0,0 +1,772 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.junit.Assume.assumeNoException; + +import com.google.common.net.InetAddresses; +import com.google.common.testing.EqualsTester; +import java.net.Inet6Address; +import java.net.URISyntaxException; +import java.net.UnknownHostException; +import java.util.BitSet; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class UriTest { + + @Test + public void parse_allComponents() throws URISyntaxException { + Uri uri = Uri.parse("scheme://user@host:0443/path?query#fragment"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isEqualTo("user@host:0443"); + assertThat(uri.getUserInfo()).isEqualTo("user"); + assertThat(uri.getPort()).isEqualTo(443); + assertThat(uri.getRawPort()).isEqualTo("0443"); + assertThat(uri.getPath()).isEqualTo("/path"); + assertThat(uri.getQuery()).isEqualTo("query"); + assertThat(uri.getFragment()).isEqualTo("fragment"); + assertThat(uri.toString()).isEqualTo("scheme://user@host:0443/path?query#fragment"); + assertThat(uri.isAbsolute()).isFalse(); // Has a fragment. + assertThat(uri.isPathAbsolute()).isTrue(); + assertThat(uri.isPathRootless()).isFalse(); + } + + @Test + public void parse_noAuthority() throws URISyntaxException { + Uri uri = Uri.parse("scheme:/path?query#fragment"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isNull(); + assertThat(uri.getPath()).isEqualTo("/path"); + assertThat(uri.getQuery()).isEqualTo("query"); + assertThat(uri.getFragment()).isEqualTo("fragment"); + assertThat(uri.toString()).isEqualTo("scheme:/path?query#fragment"); + assertThat(uri.isAbsolute()).isFalse(); // Has a fragment. 
+ } + + @Test + public void parse_ipv6Literal_withPort() throws URISyntaxException { + Uri uri = Uri.parse("scheme://[2001:db8::7]:012345"); + assertThat(uri.getAuthority()).isEqualTo("[2001:db8::7]:012345"); + assertThat(uri.getRawHost()).isEqualTo("[2001:db8::7]"); + assertThat(uri.getHost()).isEqualTo("[2001:db8::7]"); + assertThat(uri.getRawPort()).isEqualTo("012345"); + assertThat(uri.getPort()).isEqualTo(12345); + } + + @Test + public void parse_ipv6Literal_noPort() throws URISyntaxException { + Uri uri = Uri.parse("scheme://[2001:db8::7]"); + assertThat(uri.getAuthority()).isEqualTo("[2001:db8::7]"); + assertThat(uri.getRawHost()).isEqualTo("[2001:db8::7]"); + assertThat(uri.getHost()).isEqualTo("[2001:db8::7]"); + assertThat(uri.getRawPort()).isNull(); + assertThat(uri.getPort()).isLessThan(0); + } + + @Test + public void parse_ipv6ScopedLiteral() throws URISyntaxException { + Uri uri = Uri.parse("http://[fe80::1%25eth0]"); + assertThat(uri.getRawHost()).isEqualTo("[fe80::1%25eth0]"); + assertThat(uri.getHost()).isEqualTo("[fe80::1%eth0]"); + } + + @Test + public void parse_ipv6ScopedPercentEncodedLiteral() throws URISyntaxException { + Uri uri = Uri.parse("http://[fe80::1%25foo-bar%2Fblah]"); + assertThat(uri.getRawHost()).isEqualTo("[fe80::1%25foo-bar%2Fblah]"); + assertThat(uri.getHost()).isEqualTo("[fe80::1%foo-bar/blah]"); + } + + @Test + public void parse_noQuery() throws URISyntaxException { + Uri uri = Uri.parse("scheme://authority/path#fragment"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isEqualTo("authority"); + assertThat(uri.getPath()).isEqualTo("/path"); + assertThat(uri.getQuery()).isNull(); + assertThat(uri.getFragment()).isEqualTo("fragment"); + assertThat(uri.toString()).isEqualTo("scheme://authority/path#fragment"); + } + + @Test + public void parse_noFragment() throws URISyntaxException { + Uri uri = Uri.parse("scheme://authority/path?query"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + 
assertThat(uri.getAuthority()).isEqualTo("authority"); + assertThat(uri.getPath()).isEqualTo("/path"); + assertThat(uri.getQuery()).isEqualTo("query"); + assertThat(uri.getFragment()).isNull(); + assertThat(uri.toString()).isEqualTo("scheme://authority/path?query"); + assertThat(uri.isAbsolute()).isTrue(); + } + + @Test + public void parse_emptyPathWithAuthority() throws URISyntaxException { + Uri uri = Uri.parse("scheme://authority"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isEqualTo("authority"); + assertThat(uri.getPath()).isEmpty(); + assertThat(uri.getQuery()).isNull(); + assertThat(uri.getFragment()).isNull(); + assertThat(uri.toString()).isEqualTo("scheme://authority"); + assertThat(uri.isAbsolute()).isTrue(); + assertThat(uri.isPathAbsolute()).isFalse(); + assertThat(uri.isPathRootless()).isFalse(); + } + + @Test + public void parse_rootless() throws URISyntaxException { + Uri uri = Uri.parse("mailto:ceo@company.com?subject=raise"); + assertThat(uri.getScheme()).isEqualTo("mailto"); + assertThat(uri.getAuthority()).isNull(); + assertThat(uri.getPath()).isEqualTo("ceo@company.com"); + assertThat(uri.getQuery()).isEqualTo("subject=raise"); + assertThat(uri.getFragment()).isNull(); + assertThat(uri.toString()).isEqualTo("mailto:ceo@company.com?subject=raise"); + assertThat(uri.isAbsolute()).isTrue(); + assertThat(uri.isPathAbsolute()).isFalse(); + assertThat(uri.isPathRootless()).isTrue(); + } + + @Test + public void parse_emptyPath() throws URISyntaxException { + Uri uri = Uri.parse("scheme:"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isNull(); + assertThat(uri.getPath()).isEmpty(); + assertThat(uri.getQuery()).isNull(); + assertThat(uri.getFragment()).isNull(); + assertThat(uri.toString()).isEqualTo("scheme:"); + assertThat(uri.isAbsolute()).isTrue(); + assertThat(uri.isPathAbsolute()).isFalse(); + assertThat(uri.isPathRootless()).isFalse(); + } + + @Test + public void 
parse_emptyQuery() { + Uri uri = Uri.create("scheme:?"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getQuery()).isEmpty(); + } + + @Test + public void parse_emptyFragment() { + Uri uri = Uri.create("scheme:#"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getFragment()).isEmpty(); + } + + @Test + public void parse_emptyUserInfo() { + Uri uri = Uri.create("scheme://@host"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isEqualTo("@host"); + assertThat(uri.getHost()).isEqualTo("host"); + assertThat(uri.getUserInfo()).isEmpty(); + assertThat(uri.toString()).isEqualTo("scheme://@host"); + } + + @Test + public void parse_emptyPort() { + Uri uri = Uri.create("scheme://host:"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isEqualTo("host:"); + assertThat(uri.getRawAuthority()).isEqualTo("host:"); + assertThat(uri.getHost()).isEqualTo("host"); + assertThat(uri.getPort()).isEqualTo(-1); + assertThat(uri.getRawPort()).isEqualTo(""); + assertThat(uri.toString()).isEqualTo("scheme://host:"); + } + + @Test + public void parse_invalidScheme_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("1scheme://authority/path")); + assertThat(e).hasMessageThat().contains("Scheme must start with an alphabetic char"); + + e = assertThrows(URISyntaxException.class, () -> Uri.parse(":path")); + assertThat(e).hasMessageThat().contains("Scheme must start with an alphabetic char"); + } + + @Test + public void parse_unTerminatedScheme_throws() { + URISyntaxException e = assertThrows(URISyntaxException.class, () -> Uri.parse("scheme/")); + assertThat(e).hasMessageThat().contains("Missing required scheme"); + + e = assertThrows(URISyntaxException.class, () -> Uri.parse("scheme?")); + assertThat(e).hasMessageThat().contains("Missing required scheme"); + + e = assertThrows(URISyntaxException.class, () -> Uri.parse("scheme#")); + 
assertThat(e).hasMessageThat().contains("Missing required scheme"); + } + + @Test + public void parse_invalidCharactersInScheme_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("schem e://authority/path")); + assertThat(e).hasMessageThat().contains("Invalid character in scheme"); + } + + @Test + public void parse_unTerminatedAuthority_throws() { + Uri uri = Uri.create("s://auth/"); + assertThat(uri.getAuthority()).isEqualTo("auth"); + uri = Uri.create("s://auth?"); + assertThat(uri.getAuthority()).isEqualTo("auth"); + uri = Uri.create("s://auth#"); + assertThat(uri.getAuthority()).isEqualTo("auth"); + } + + @Test + public void parse_invalidCharactersInUserinfo_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("scheme://u ser@host/path")); + assertThat(e).hasMessageThat().contains("Invalid character in userInfo"); + } + + @Test + public void parse_invalidBackslashInUserinfo_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("http://other.com\\@intended.com")); + assertThat(e).hasMessageThat().contains("Invalid character in userInfo"); + } + + @Test + public void parse_invalidCharactersInHost_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("scheme://h ost/path")); + assertThat(e).hasMessageThat().contains("Invalid character in host"); + } + + @Test + public void parse_invalidBackslashInHost_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("http://other.com\\.intended.com")); + assertThat(e).hasMessageThat().contains("Invalid character in host"); + } + + @Test + public void parse_invalidBackslashScope_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("http://[::1%25foo\\bar]")); + assertThat(e).hasMessageThat().contains("Invalid character in scope"); + } + + @Test + public void 
parse_invalidCharactersInPort_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("scheme://user@host:8 0/path")); + assertThat(e).hasMessageThat().contains("Invalid character"); + } + + @Test + public void parse_nonAsciiCharacterInPath_throws() throws URISyntaxException { + URISyntaxException e = assertThrows(URISyntaxException.class, () -> Uri.parse("foo:bär")); + assertThat(e).hasMessageThat().contains("Invalid character in path"); + } + + @Test + public void parse_invalidCharactersInPath_throws() { + URISyntaxException e = assertThrows(URISyntaxException.class, () -> Uri.parse("scheme:/p ath")); + assertThat(e).hasMessageThat().contains("Invalid character in path"); + } + + @Test + public void parse_invalidCharactersInQuery_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("scheme://user@host/p?q[]uery")); + assertThat(e).hasMessageThat().contains("Invalid character in query"); + } + + @Test + public void parse_invalidCharactersInFragment_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("scheme://user@host/path#f[]rag")); + assertThat(e).hasMessageThat().contains("Invalid character in fragment"); + } + + @Test + public void parse_nonAsciiCharacterInFragment_throws() throws URISyntaxException { + URISyntaxException e = assertThrows(URISyntaxException.class, () -> Uri.parse("foo:#bär")); + assertThat(e).hasMessageThat().contains("Invalid character in fragment"); + } + + @Test + public void parse_decoding() throws URISyntaxException { + Uri uri = Uri.parse("s://user%2Ename:pass%2Eword@a%2db:1234/p%20ath?q%20uery#f%20ragment"); + assertThat(uri.getAuthority()).isEqualTo("user.name:pass.word@a-b:1234"); + assertThat(uri.getRawAuthority()).isEqualTo("user%2Ename:pass%2Eword@a%2db:1234"); + assertThat(uri.getUserInfo()).isEqualTo("user.name:pass.word"); + assertThat(uri.getRawUserInfo()).isEqualTo("user%2Ename:pass%2Eword"); + 
assertThat(uri.getHost()).isEqualTo("a-b"); + assertThat(uri.getRawHost()).isEqualTo("a%2db"); + assertThat(uri.getPort()).isEqualTo(1234); + assertThat(uri.getPath()).isEqualTo("/p ath"); + assertThat(uri.getRawPath()).isEqualTo("/p%20ath"); + assertThat(uri.getQuery()).isEqualTo("q uery"); + assertThat(uri.getRawQuery()).isEqualTo("q%20uery"); + assertThat(uri.getFragment()).isEqualTo("f ragment"); + assertThat(uri.getRawFragment()).isEqualTo("f%20ragment"); + } + + @Test + public void parse_decodingNonAscii() throws URISyntaxException { + Uri uri = Uri.parse("s://a/%E2%82%AC"); + assertThat(uri.getPath()).isEqualTo("/€"); + } + + @Test + public void parse_decodingPercent() throws URISyntaxException { + Uri uri = Uri.parse("s://a/p%2520ath?q%25uery#f%25ragment"); + assertThat(uri.getPath()).isEqualTo("/p%20ath"); + assertThat(uri.getQuery()).isEqualTo("q%uery"); + assertThat(uri.getFragment()).isEqualTo("f%ragment"); + } + + @Test + public void parse_invalidPercentEncoding_throws() { + URISyntaxException e = assertThrows(URISyntaxException.class, () -> Uri.parse("s://a/p%2")); + assertThat(e).hasMessageThat().contains("Invalid"); + + e = assertThrows(URISyntaxException.class, () -> Uri.parse("s://a/p%2G")); + assertThat(e).hasMessageThat().contains("Invalid"); + } + + @Test + public void parse_emptyAuthority() { + Uri uri = Uri.create("file:///foo/bar"); + assertThat(uri.getAuthority()).isEmpty(); + assertThat(uri.getHost()).isEmpty(); + assertThat(uri.getUserInfo()).isNull(); + assertThat(uri.getPort()).isEqualTo(-1); + assertThat(uri.getPath()).isEqualTo("/foo/bar"); + } + + @Test + public void parse_pathSegments_empty() throws URISyntaxException { + Uri uri = Uri.create("scheme:"); + assertThat(uri.getPathSegments()).isEmpty(); + } + + @Test + public void parse_pathSegments_root() throws URISyntaxException { + Uri uri = Uri.create("scheme:/"); + assertThat(uri.getPathSegments()).containsExactly(""); + } + + @Test + public void parse_onePathSegment() throws 
URISyntaxException { + Uri uri = Uri.create("file:/foo"); + assertThat(uri.getPathSegments()).containsExactly("foo"); + } + + @Test + public void parse_onePathSegment_trailingSlash() throws URISyntaxException { + Uri uri = Uri.create("file:/foo/"); + assertThat(uri.getPathSegments()).containsExactly("foo", ""); + } + + @Test + public void parse_onePathSegment_rootless() throws URISyntaxException { + Uri uri = Uri.create("dns:www.example.com"); + assertThat(uri.getPathSegments()).containsExactly("www.example.com"); + assertThat(uri.isPathAbsolute()).isFalse(); + assertThat(uri.isPathRootless()).isTrue(); + } + + @Test + public void parse_twoPathSegments() throws URISyntaxException { + Uri uri = Uri.create("file:/foo/bar"); + assertThat(uri.getPathSegments()).containsExactly("foo", "bar"); + } + + @Test + public void parse_twoPathSegments_rootless() throws URISyntaxException { + Uri uri = Uri.create("file:foo/bar"); + assertThat(uri.getPathSegments()).containsExactly("foo", "bar"); + } + + @Test + public void parse_percentEncodedPathSegment_rootless() throws URISyntaxException { + Uri uri = Uri.create("mailto:%2Fdev%2Fnull@example.com"); + assertThat(uri.getPathSegments()).containsExactly("/dev/null@example.com"); + assertThat(uri.isPathAbsolute()).isFalse(); + assertThat(uri.isPathRootless()).isTrue(); + } + + @Test + public void toString_percentEncoding() throws URISyntaxException { + Uri uri = + Uri.newBuilder() + .setScheme("s") + .setHost("a b") + .setPath("/p ath") + .setQuery("q uery") + .setFragment("f ragment") + .build(); + assertThat(uri.toString()).isEqualTo("s://a%20b/p%20ath?q%20uery#f%20ragment"); + } + + @Test + public void parse_transparentRoundTrip_ipLiteral() { + Uri uri = Uri.create("http://[2001:dB8::7]:080/%4a%4B%2f%2F?%4c%4D#%4e%4F").toBuilder().build(); + assertThat(uri.toString()).isEqualTo("http://[2001:dB8::7]:080/%4a%4B%2f%2F?%4c%4D#%4e%4F"); + + // IPv6 host has non-canonical :: zeros and mixed case hex digits. 
+ assertThat(uri.getRawHost()).isEqualTo("[2001:dB8::7]"); + assertThat(uri.getHost()).isEqualTo("[2001:dB8::7]"); + assertThat(uri.getRawPort()).isEqualTo("080"); // Leading zeros. + assertThat(uri.getPort()).isEqualTo(80); + // Unnecessary and mixed case percent encodings. + assertThat(uri.getRawPath()).isEqualTo("/%4a%4B%2f%2F"); + assertThat(uri.getPathSegments()).containsExactly("JK//"); + assertThat(uri.getRawQuery()).isEqualTo("%4c%4D"); + assertThat(uri.getQuery()).isEqualTo("LM"); + assertThat(uri.getRawFragment()).isEqualTo("%4e%4F"); + assertThat(uri.getFragment()).isEqualTo("NO"); + } + + @Test + public void parse_transparentRoundTrip_regName() { + Uri uri = Uri.create("http://aB%4A%4b:080/%4a%4B%2f%2F?%4c%4D#%4e%4F").toBuilder().build(); + assertThat(uri.toString()).isEqualTo("http://aB%4A%4b:080/%4a%4B%2f%2F?%4c%4D#%4e%4F"); + + // Mixed case literal chars and hex digits. + assertThat(uri.getRawHost()).isEqualTo("aB%4A%4b"); + assertThat(uri.getHost()).isEqualTo("aBJK"); + assertThat(uri.getRawPort()).isEqualTo("080"); // Leading zeros. + assertThat(uri.getPort()).isEqualTo(80); + // Unnecessary and mixed case percent encodings. 
+ assertThat(uri.getRawPath()).isEqualTo("/%4a%4B%2f%2F"); + assertThat(uri.getPathSegments()).containsExactly("JK//"); + assertThat(uri.getRawQuery()).isEqualTo("%4c%4D"); + assertThat(uri.getQuery()).isEqualTo("LM"); + assertThat(uri.getRawFragment()).isEqualTo("%4e%4F"); + assertThat(uri.getFragment()).isEqualTo("NO"); + } + + @Test + public void builder_numericPort() throws URISyntaxException { + Uri uri = Uri.newBuilder().setScheme("scheme").setHost("host").setPort(80).build(); + assertThat(uri.toString()).isEqualTo("scheme://host:80"); + } + + @Test + public void builder_ipv6Literal() throws URISyntaxException { + Uri uri = + Uri.newBuilder() + .setScheme("scheme") + .setHost(InetAddresses.forString("2001:4860:4860::8844")) + .build(); + assertThat(uri.toString()).isEqualTo("scheme://[2001:4860:4860::8844]"); + } + + @Test + public void builder_ipv6ScopedLiteral_numeric() throws UnknownHostException { + Uri uri = + Uri.newBuilder() + .setScheme("http") + // Create an address with a numeric scope_id, which should always be valid. + .setHost( + Inet6Address.getByAddress(null, InetAddresses.forString("fe80::1").getAddress(), 1)) + .build(); + + // We expect the scope ID to be percent encoded. + assertThat(uri.getRawHost()).isEqualTo("[fe80::1%251]"); + assertThat(uri.getHost()).isEqualTo("[fe80::1%1]"); + } + + @Test + public void builder_ipv6ScopedLiteral_named() throws UnknownHostException { + // Unfortunately, there's no Java API to create an Inet6Address with an arbitrary interface- + // scoped name. There's actually no way to hermetically create an Inet6Address with a scope name + // at all! The following address/interface is likely to be present on Linux test runners. + Inet6Address address; + try { + address = (Inet6Address) InetAddresses.forString("::1%lo"); + } catch (IllegalArgumentException e) { + assumeNoException(e); + return; // Not reached. 
+ } + Uri uri = Uri.newBuilder().setScheme("http").setHost(address).build(); + + // We expect the scope ID to be percent encoded. + assertThat(uri.getRawHost()).isEqualTo("[::1%25lo]"); + assertThat(uri.getHost()).isEqualTo("[::1%lo]"); + } + + @Test + public void builder_ipv6PercentEncodedScopedLiteral() { + Uri uri = Uri.newBuilder().setScheme("http").setRawHost("[fe80::1%25foo%2Dbar%2Fblah]").build(); + assertThat(uri.getRawHost()).isEqualTo("[fe80::1%25foo%2Dbar%2Fblah]"); + assertThat(uri.getHost()).isEqualTo("[fe80::1%foo-bar/blah]"); + } + + @Test + public void builder_encodingWithAllowedReservedChars() throws URISyntaxException { + Uri uri = + Uri.newBuilder() + .setScheme("s") + .setUserInfo("u@") + .setHost("a[]") + .setPath("/p:/@") + .setQuery("q/?") + .setFragment("f/?") + .build(); + assertThat(uri.toString()).isEqualTo("s://u%40@a%5B%5D/p:/@?q/?#f/?"); + } + + @Test + public void builder_percentEncodingNonAscii() throws URISyntaxException { + Uri uri = Uri.newBuilder().setScheme("s").setHost("a").setPath("/€").build(); + assertThat(uri.toString()).isEqualTo("s://a/%E2%82%AC"); + } + + @Test + public void builder_percentEncodingLoneHighSurrogate_throws() { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> Uri.newBuilder().setPath("\uD83D")); // Lone high surrogate. 
+ assertThat(e.getMessage()).contains("Malformed input"); + } + + @Test + public void builder_hasAuthority_pathStartsWithSlash_throws() throws URISyntaxException { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> Uri.newBuilder().setScheme("s").setHost("a").setPath("path").build()); + assertThat(e.getMessage()).contains("Non-empty path must start with '/'"); + } + + @Test + public void builder_noAuthority_pathStartsWithDoubleSlash_throws() throws URISyntaxException { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> Uri.newBuilder().setScheme("s").setPath("//path").build()); + assertThat(e.getMessage()).contains("Path cannot start with '//'"); + } + + @Test + public void builder_noScheme_throws() { + IllegalStateException e = + assertThrows(IllegalStateException.class, () -> Uri.newBuilder().build()); + assertThat(e.getMessage()).contains("Missing required scheme"); + } + + @Test + public void builder_noHost_hasUserInfo_throws() { + IllegalStateException e = + assertThrows( + IllegalStateException.class, + () -> Uri.newBuilder().setScheme("scheme").setUserInfo("user").build()); + assertThat(e.getMessage()).contains("Cannot set userInfo without host"); + } + + @Test + public void builder_noHost_hasPort_throws() { + IllegalStateException e = + assertThrows( + IllegalStateException.class, + () -> Uri.newBuilder().setScheme("scheme").setPort(1234).build()); + assertThat(e.getMessage()).contains("Cannot set port without host"); + } + + @Test + public void builder_normalizesCaseWhereAppropriate() { + Uri uri = + Uri.newBuilder() + .setScheme("hTtP") // #section-3.1 says producers (Builder) should normalize to lower. + .setHost("aBc") // #section-3.2.2 says producers (Builder) should normalize to lower. 
+ .setPath("/CdE") // #section-6.2.2.1 says the rest are assumed to be case-sensitive + .setQuery("fGh") + .setFragment("IjK") + .build(); + assertThat(uri.toString()).isEqualTo("http://abc/CdE?fGh#IjK"); + } + + @Test + public void builder_normalizesIpv6Literal() { + Uri uri = + Uri.newBuilder().setScheme("scheme").setHost(InetAddresses.forString("ABCD::EFAB")).build(); + assertThat(uri.toString()).isEqualTo("scheme://[abcd::efab]"); + } + + @Test + public void builder_canClearAllOptionalFields() { + Uri uri = + Uri.create("http://user@host:80/path?query#fragment").toBuilder() + .setHost((String) null) + .setPath("") + .setUserInfo(null) + .setPort(-1) + .setQuery(null) + .setFragment(null) + .build(); + assertThat(uri.toString()).isEqualTo("http:"); + } + + @Test + public void builder_canClearAuthorityComponents() { + Uri uri = Uri.create("s://user@host:80/path").toBuilder().setRawAuthority(null).build(); + assertThat(uri.toString()).isEqualTo("s:/path"); + } + + @Test + public void builder_canSetEmptyAuthority() { + Uri uri = Uri.create("s://user@host:80/path").toBuilder().setRawAuthority("").build(); + assertThat(uri.toString()).isEqualTo("s:///path"); + } + + @Test + public void builder_canSetRawAuthority() { + Uri uri = Uri.newBuilder().setScheme("http").setRawAuthority("user@host:1234").build(); + assertThat(uri.getUserInfo()).isEqualTo("user"); + assertThat(uri.getHost()).isEqualTo("host"); + assertThat(uri.getPort()).isEqualTo(1234); + } + + @Test + public void builder_setRawAuthorityPercentDecodes() { + Uri uri = + Uri.newBuilder() + .setScheme("http") + .setRawAuthority("user:user%40user@host%40host%3Ahost") + .build(); + assertThat(uri.getUserInfo()).isEqualTo("user:user@user"); + assertThat(uri.getHost()).isEqualTo("host@host:host"); + assertThat(uri.getPort()).isEqualTo(-1); + } + + @Test + public void builder_setRawAuthorityReplacesAllComponents() { + Uri uri = + Uri.newBuilder() + .setScheme("http") + .setUserInfo("user") + .setHost("host") + 
.setPort(1234) + .setRawAuthority("other") + .build(); + assertThat(uri.getUserInfo()).isNull(); + assertThat(uri.getHost()).isEqualTo("other"); + assertThat(uri.getPort()).isEqualTo(-1); + } + + @Test + public void toString_percentEncodingMultiChar() throws URISyntaxException { + Uri uri = + Uri.newBuilder() + .setScheme("s") + .setHost("a") + .setPath("/emojis/😊/icon.png") // Smile requires two chars to express in a java String. + .build(); + assertThat(uri.toString()).isEqualTo("s://a/emojis/%F0%9F%98%8A/icon.png"); + } + + @Test + public void toString_percentEncodingLiteralPercent() throws URISyntaxException { + Uri uri = + Uri.newBuilder() + .setScheme("s") + .setHost("a") + .setPath("/p%20ath") + .setQuery("q%uery") + .setFragment("f%ragment") + .build(); + assertThat(uri.toString()).isEqualTo("s://a/p%2520ath?q%25uery#f%25ragment"); + } + + @Test + public void equalsAndHashCode() { + new EqualsTester() + .addEqualityGroup( + Uri.create("scheme://authority/path?query#fragment"), + Uri.create("scheme://authority/path?query#fragment")) + .addEqualityGroup(Uri.create("scheme://authority/path")) + .addEqualityGroup(Uri.create("scheme://authority/path?query")) + .addEqualityGroup(Uri.create("scheme:/path")) + .addEqualityGroup(Uri.create("scheme:/path?query")) + .addEqualityGroup(Uri.create("scheme:/path#fragment")) + .addEqualityGroup(Uri.create("scheme:path")) + .addEqualityGroup(Uri.create("scheme:path?query")) + .addEqualityGroup(Uri.create("scheme:path#fragment")) + .addEqualityGroup(Uri.create("scheme:")) + .testEquals(); + } + + @Test + public void isAbsolute() { + assertThat(Uri.create("scheme://authority/path").isAbsolute()).isTrue(); + assertThat(Uri.create("scheme://authority/path?query").isAbsolute()).isTrue(); + assertThat(Uri.create("scheme://authority/path#fragment").isAbsolute()).isFalse(); + assertThat(Uri.create("scheme://authority/path?query#fragment").isAbsolute()).isFalse(); + } + + @Test + public void 
serializedCharacterClasses_matchComputed() { + assertThat(Uri.digitChars).isEqualTo(bitSetOfRange('0', '9')); + assertThat(Uri.alphaChars).isEqualTo(or(bitSetOfRange('A', 'Z'), bitSetOfRange('a', 'z'))); + assertThat(Uri.schemeChars) + .isEqualTo(or(Uri.digitChars, Uri.alphaChars, bitSetOf('+', '-', '.'))); + assertThat(Uri.unreservedChars) + .isEqualTo(or(Uri.alphaChars, Uri.digitChars, bitSetOf('-', '.', '_', '~'))); + assertThat(Uri.genDelimsChars).isEqualTo(bitSetOf(':', '/', '?', '#', '[', ']', '@')); + assertThat(Uri.subDelimsChars) + .isEqualTo(bitSetOf('!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=')); + assertThat(Uri.reservedChars).isEqualTo(or(Uri.genDelimsChars, Uri.subDelimsChars)); + assertThat(Uri.regNameChars).isEqualTo(or(Uri.unreservedChars, Uri.subDelimsChars)); + assertThat(Uri.userInfoChars) + .isEqualTo(or(Uri.unreservedChars, Uri.subDelimsChars, bitSetOf(':'))); + assertThat(Uri.pChars) + .isEqualTo(or(Uri.unreservedChars, Uri.subDelimsChars, bitSetOf(':', '@'))); + assertThat(Uri.pCharsAndSlash).isEqualTo(or(Uri.pChars, bitSetOf('/'))); + assertThat(Uri.queryChars).isEqualTo(or(Uri.pChars, bitSetOf('/', '?'))); + assertThat(Uri.fragmentChars).isEqualTo(or(Uri.pChars, bitSetOf('/', '?'))); + } + + private static BitSet bitSetOfRange(char from, char to) { + BitSet bitset = new BitSet(); + for (char c = from; c <= to; c++) { + bitset.set(c); + } + return bitset; + } + + private static BitSet bitSetOf(char... chars) { + BitSet bitset = new BitSet(); + for (char c : chars) { + bitset.set(c); + } + return bitset; + } + + private static BitSet or(BitSet... 
bitsets) { + BitSet bitset = new BitSet(); + for (BitSet bs : bitsets) { + bitset.or(bs); + } + return bitset; + } +} diff --git a/api/src/testFixtures/java/io/grpc/FlagResetRule.java b/api/src/testFixtures/java/io/grpc/FlagResetRule.java new file mode 100644 index 00000000000..08ce7ce82f2 --- /dev/null +++ b/api/src/testFixtures/java/io/grpc/FlagResetRule.java @@ -0,0 +1,96 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.ArrayDeque; +import java.util.Deque; +import javax.annotation.Nullable; +import org.junit.rules.ExternalResource; + +/** + * A {@link org.junit.rules.TestRule} that lets you set one or more feature flags just for + * the duration of the current test case. + * + *

Flags and other global variables must be reset to ensure no state leaks across tests. + */ +public final class FlagResetRule extends ExternalResource { + + /** A functional interface representing a standard gRPC feature flag setter. */ + public interface SetterMethod { + /** Sets a flag for testing and returns its previous value. */ + T set(T val); + } + + private final Deque toRunAfter = new ArrayDeque<>(); + + /** + * Sets a global feature flag to 'value' using 'setter' and arranges for its previous value to be + * unconditionally restored when the test completes. + */ + public void setFlagForTest(SetterMethod setter, T value) { + final T oldValue = setter.set(value); + toRunAfter.push(() -> setter.set(oldValue)); + } + + /** + * Sets java system property 'key' to 'value' and arranges for its previous value to be + * unconditionally restored when the test completes. + */ + public void setSystemPropertyForTest(String key, String value) { + String oldValue = System.setProperty(key, value); + restoreSystemPropertyAfterTest(key, oldValue); + } + + /** + * Clears java system property 'key' and arranges for its previous value to be unconditionally + * restored when the test completes. 
+ */ + public void clearSystemPropertyForTest(String key) { + String oldValue = System.clearProperty(key); + restoreSystemPropertyAfterTest(key, oldValue); + } + + private void restoreSystemPropertyAfterTest(String key, @Nullable String oldValue) { + toRunAfter.push( + () -> { + if (oldValue == null) { + System.clearProperty(key); + } else { + System.setProperty(key, oldValue); + } + }); + } + + @Override + protected void after() { + RuntimeException toThrow = null; + while (!toRunAfter.isEmpty()) { + try { + toRunAfter.pop().run(); + } catch (RuntimeException e) { + if (toThrow == null) { + toThrow = e; + } else { + toThrow.addSuppressed(e); + } + } + } + if (toThrow != null) { + throw toThrow; + } + } +} diff --git a/api/src/testFixtures/java/io/grpc/StatusMatcher.java b/api/src/testFixtures/java/io/grpc/StatusMatcher.java new file mode 100644 index 00000000000..f464b2d709d --- /dev/null +++ b/api/src/testFixtures/java/io/grpc/StatusMatcher.java @@ -0,0 +1,118 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import org.mockito.ArgumentMatcher; + +/** + * Mockito matcher for {@link Status}. 
+ */ +public final class StatusMatcher implements ArgumentMatcher { + public static StatusMatcher statusHasCode(ArgumentMatcher codeMatcher) { + return new StatusMatcher(codeMatcher, null); + } + + public static StatusMatcher statusHasCode(Status.Code code) { + return statusHasCode(new EqualsMatcher<>(code)); + } + + private final ArgumentMatcher codeMatcher; + private final ArgumentMatcher descriptionMatcher; + + private StatusMatcher( + ArgumentMatcher codeMatcher, + ArgumentMatcher descriptionMatcher) { + this.codeMatcher = checkNotNull(codeMatcher, "codeMatcher"); + this.descriptionMatcher = descriptionMatcher; + } + + public StatusMatcher andDescription(ArgumentMatcher descriptionMatcher) { + checkState(this.descriptionMatcher == null, "Already has a description matcher"); + return new StatusMatcher(codeMatcher, descriptionMatcher); + } + + public StatusMatcher andDescription(String description) { + return andDescription(new EqualsMatcher<>(description)); + } + + public StatusMatcher andDescriptionContains(String substring) { + return andDescription(new StringContainsMatcher(substring)); + } + + @Override + public boolean matches(Status status) { + return status != null + && codeMatcher.matches(status.getCode()) + && (descriptionMatcher == null || descriptionMatcher.matches(status.getDescription())); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("{code="); + sb.append(codeMatcher); + if (descriptionMatcher != null) { + sb.append(", description="); + sb.append(descriptionMatcher); + } + sb.append("}"); + return sb.toString(); + } + + // Use instead of lambda for better error message. 
+ static final class EqualsMatcher implements ArgumentMatcher { + private final T obj; + + EqualsMatcher(T obj) { + this.obj = checkNotNull(obj, "obj"); + } + + @Override + public boolean matches(Object other) { + return obj.equals(other); + } + + @Override + public String toString() { + return obj.toString(); + } + } + + static final class StringContainsMatcher implements ArgumentMatcher { + private final String needle; + + StringContainsMatcher(String needle) { + this.needle = checkNotNull(needle, "needle"); + } + + @Override + public boolean matches(String haystack) { + if (haystack == null) { + return false; + } + return haystack.contains(needle); + } + + @Override + public String toString() { + return "contains " + needle; + } + } +} diff --git a/api/src/testFixtures/java/io/grpc/StatusOrMatcher.java b/api/src/testFixtures/java/io/grpc/StatusOrMatcher.java new file mode 100644 index 00000000000..1e70ae97853 --- /dev/null +++ b/api/src/testFixtures/java/io/grpc/StatusOrMatcher.java @@ -0,0 +1,66 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkNotNull; + +import org.mockito.ArgumentMatcher; + +/** + * Mockito matcher for {@link StatusOr}. 
+ */ +public final class StatusOrMatcher implements ArgumentMatcher> { + public static StatusOrMatcher hasValue(ArgumentMatcher valueMatcher) { + return new StatusOrMatcher(checkNotNull(valueMatcher, "valueMatcher"), null); + } + + public static StatusOrMatcher hasStatus(ArgumentMatcher statusMatcher) { + return new StatusOrMatcher(null, checkNotNull(statusMatcher, "statusMatcher")); + } + + private final ArgumentMatcher valueMatcher; + private final ArgumentMatcher statusMatcher; + + private StatusOrMatcher(ArgumentMatcher valueMatcher, ArgumentMatcher statusMatcher) { + this.valueMatcher = valueMatcher; + this.statusMatcher = statusMatcher; + } + + @Override + public boolean matches(StatusOr statusOr) { + if (statusOr == null) { + return false; + } + if (statusOr.hasValue() != (valueMatcher != null)) { + return false; + } + if (valueMatcher != null) { + return valueMatcher.matches(statusOr.getValue()); + } else { + return statusMatcher.matches(statusOr.getStatus()); + } + } + + @Override + public String toString() { + if (valueMatcher != null) { + return "{value=" + valueMatcher + "}"; + } else { + return "{status=" + statusMatcher + "}"; + } + } +} diff --git a/api/src/testFixtures/java/io/grpc/testing/DeadlineSubject.java b/api/src/testFixtures/java/io/grpc/testing/DeadlineSubject.java index 5d4e86fac15..c2b4d8412a7 100644 --- a/api/src/testFixtures/java/io/grpc/testing/DeadlineSubject.java +++ b/api/src/testFixtures/java/io/grpc/testing/DeadlineSubject.java @@ -24,9 +24,9 @@ import com.google.common.truth.ComparableSubject; import com.google.common.truth.FailureMetadata; import com.google.common.truth.Subject; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Deadline; import java.util.concurrent.TimeUnit; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; /** Propositions for {@link Deadline} subjects. 
*/ @@ -67,7 +67,7 @@ public void of(Deadline expected) { if (Math.abs(actualNanos - expectedNanos) > deltaNanos) { failWithoutActual( fact("expected", expectedNanos / NANOSECONDS_IN_A_SECOND), - fact("but was", expectedNanos / NANOSECONDS_IN_A_SECOND), + fact("but was", actualNanos / NANOSECONDS_IN_A_SECOND), fact("outside tolerance in seconds", deltaNanos / NANOSECONDS_IN_A_SECOND)); } } diff --git a/auth/BUILD.bazel b/auth/BUILD.bazel index a19562fa7f7..da44243e583 100644 --- a/auth/BUILD.bazel +++ b/auth/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( diff --git a/auth/build.gradle b/auth/build.gradle index 78bb720601b..d56802c14ca 100644 --- a/auth/build.gradle +++ b/auth/build.gradle @@ -22,6 +22,14 @@ dependencies { project(':grpc-core'), project(":grpc-context"), // Override google-auth dependency with our newer version libraries.google.auth.oauth2Http - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } diff --git a/auth/src/test/java/io/grpc/auth/GoogleAuthLibraryCallCredentialsTest.java b/auth/src/test/java/io/grpc/auth/GoogleAuthLibraryCallCredentialsTest.java index 1e8c27bca25..75026fd7c18 100644 --- a/auth/src/test/java/io/grpc/auth/GoogleAuthLibraryCallCredentialsTest.java +++ b/auth/src/test/java/io/grpc/auth/GoogleAuthLibraryCallCredentialsTest.java @@ -50,10 +50,12 @@ import io.grpc.Status; import io.grpc.internal.JsonParser; import io.grpc.testing.TestMethodDescriptors; +import io.grpc.testing.TlsTesting; +import io.grpc.util.CertificateUtils; import java.io.IOException; +import java.io.InputStream; import java.net.URI; -import java.security.KeyPair; -import java.security.KeyPairGenerator; +import java.security.PrivateKey; import java.util.ArrayList; 
import java.util.Date; import java.util.List; @@ -342,7 +344,10 @@ public void serviceUri() throws Exception { @Test public void serviceAccountToJwt() throws Exception { - KeyPair pair = KeyPairGenerator.getInstance("RSA").generateKeyPair(); + PrivateKey privateKey; + try (InputStream server1Key = TlsTesting.loadCert("server1.key")) { + privateKey = CertificateUtils.getPrivateKey(server1Key); + } HttpTransportFactory factory = Mockito.mock(HttpTransportFactory.class); Mockito.when(factory.create()).thenThrow(new AssertionError()); @@ -350,7 +355,7 @@ public void serviceAccountToJwt() throws Exception { ServiceAccountCredentials credentials = ServiceAccountCredentials.newBuilder() .setClientEmail("test-email@example.com") - .setPrivateKey(pair.getPrivate()) + .setPrivateKey(privateKey) .setPrivateKeyId("test-private-key-id") .setHttpTransportFactory(factory) .build(); @@ -390,13 +395,16 @@ public void oauthClassesNotInClassPath() throws Exception { @Test public void jwtAccessCredentialsInRequestMetadata() throws Exception { - KeyPair pair = KeyPairGenerator.getInstance("RSA").generateKeyPair(); + PrivateKey privateKey; + try (InputStream server1Key = TlsTesting.loadCert("server1.key")) { + privateKey = CertificateUtils.getPrivateKey(server1Key); + } ServiceAccountCredentials credentials = ServiceAccountCredentials.newBuilder() .setClientId("test-client") .setClientEmail("test-email@example.com") - .setPrivateKey(pair.getPrivate()) + .setPrivateKey(privateKey) .setPrivateKeyId("test-private-key-id") .setQuotaProjectId("test-quota-project-id") .build(); diff --git a/authz/build.gradle b/authz/build.gradle index 491e8f32a74..4b02b01aa29 100644 --- a/authz/build.gradle +++ b/authz/build.gradle @@ -2,8 +2,8 @@ plugins { id "java-library" id "maven-publish" - id "com.github.johnrengelman.shadow" id "com.google.protobuf" + id "com.gradleup.shadow" id "ru.vyarus.animalsniffer" } @@ -15,7 +15,6 @@ dependencies { libraries.guava.jre // JRE required by transitive 
protobuf-java-util annotationProcessor libraries.auto.value - compileOnly libraries.javax.annotation testImplementation project(':grpc-testing'), project(':grpc-testing-proto'), @@ -26,7 +25,11 @@ dependencies { shadow configurations.implementation.getDependencies().minus([xdsDependency]) shadow project(path: ':grpc-xds', configuration: 'shadow') - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } tasks.named("jar").configure { diff --git a/authz/src/main/java/io/grpc/authz/AuthorizationPolicyTranslator.java b/authz/src/main/java/io/grpc/authz/AuthorizationPolicyTranslator.java index ed7e018412c..183ae2c3f55 100644 --- a/authz/src/main/java/io/grpc/authz/AuthorizationPolicyTranslator.java +++ b/authz/src/main/java/io/grpc/authz/AuthorizationPolicyTranslator.java @@ -156,19 +156,19 @@ private static Map parseRules( } /** - * Translates a gRPC authorization policy in JSON string to Envoy RBAC policies. - * On success, will return one of the following - - * 1. One allow RBAC policy or, - * 2. Two RBAC policies, deny policy followed by allow policy. - * If the policy cannot be parsed or is invalid, an exception will be thrown. - */ + * Translates a gRPC authorization policy in JSON string to Envoy RBAC policies. + * On success, will return one of the following - + * 1. One allow RBAC policy or, + * 2. Two RBAC policies, deny policy followed by allow policy. + * If the policy cannot be parsed or is invalid, an exception will be thrown. + */ public static List translate(String authorizationPolicy) throws IllegalArgumentException, IOException { Object jsonObject = JsonParser.parse(authorizationPolicy); if (!(jsonObject instanceof Map)) { throw new IllegalArgumentException( - "Authorization policy should be a JSON object. Found: " - + (jsonObject == null ? null : jsonObject.getClass())); + "Authorization policy should be a JSON object. Found: " + + (jsonObject == null ? 
null : jsonObject.getClass())); } @SuppressWarnings("unchecked") Map json = (Map)jsonObject; diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle index d00c0d76ebe..88b26397e78 100644 --- a/benchmarks/build.gradle +++ b/benchmarks/build.gradle @@ -38,12 +38,15 @@ dependencies { classifier = "linux-x86_64" } } - compileOnly libraries.javax.annotation testImplementation libraries.junit, libraries.mockito.core - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } import net.ltgt.gradle.errorprone.CheckSeverity diff --git a/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/BenchmarkServiceGrpc.java b/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/BenchmarkServiceGrpc.java index e62c2274ee9..68e911afc4a 100644 --- a/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/BenchmarkServiceGrpc.java +++ b/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/BenchmarkServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/services.proto") @io.grpc.stub.annotations.GrpcGenerated public final class BenchmarkServiceGrpc { @@ -184,6 +181,21 @@ public BenchmarkServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions return BenchmarkServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static BenchmarkServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public BenchmarkServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new BenchmarkServiceBlockingV2Stub(channel, callOptions); + } + }; + return BenchmarkServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub 
that supports unary and streaming output calls on the service */ @@ -367,6 +379,87 @@ public io.grpc.stub.StreamObserver { + private BenchmarkServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected BenchmarkServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new BenchmarkServiceBlockingV2Stub(channel, callOptions); + } + + /** + *

+     * One request followed by one response.
+     * The server returns the client payload as-is.
+     * 
+ */ + public io.grpc.benchmarks.proto.Messages.SimpleResponse unaryCall(io.grpc.benchmarks.proto.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * Repeated sequence of one request followed by one response.
+     * Should be called streaming ping-pong
+     * The server returns the client payload as-is on each response
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getStreamingCallMethod(), getCallOptions()); + } + + /** + *
+     * Single-sided unbounded streaming from client to server
+     * The server returns the client payload as-is once the client does WritesDone
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingFromClient() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getStreamingFromClientMethod(), getCallOptions()); + } + + /** + *
+     * Single-sided unbounded streaming from server to client
+     * The server repeatedly returns the client payload as-is
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingFromServer(io.grpc.benchmarks.proto.Messages.SimpleRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamingFromServerMethod(), getCallOptions(), request); + } + + /** + *
+     * Two-sided unbounded streaming between server to client
+     * Both sides send the content of their own choice to the other
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingBothWays() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getStreamingBothWaysMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service BenchmarkService. + */ public static final class BenchmarkServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private BenchmarkServiceBlockingStub( diff --git a/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/ReportQpsScenarioServiceGrpc.java b/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/ReportQpsScenarioServiceGrpc.java index b24c3813c19..c5064875bb6 100644 --- a/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/ReportQpsScenarioServiceGrpc.java +++ b/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/ReportQpsScenarioServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/services.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ReportQpsScenarioServiceGrpc { @@ -60,6 +57,21 @@ public ReportQpsScenarioServiceStub newStub(io.grpc.Channel channel, io.grpc.Cal return ReportQpsScenarioServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ReportQpsScenarioServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ReportQpsScenarioServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReportQpsScenarioServiceBlockingV2Stub(channel, callOptions); + } + }; + return ReportQpsScenarioServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style 
stub that supports unary and streaming output calls on the service */ @@ -147,6 +159,33 @@ public void reportScenario(io.grpc.benchmarks.proto.Control.ScenarioResult reque /** * A stub to allow clients to do synchronous rpc calls to service ReportQpsScenarioService. */ + public static final class ReportQpsScenarioServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ReportQpsScenarioServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ReportQpsScenarioServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReportQpsScenarioServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Report results of a QPS test benchmark scenario.
+     * 
+ */ + public io.grpc.benchmarks.proto.Control.Void reportScenario(io.grpc.benchmarks.proto.Control.ScenarioResult request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getReportScenarioMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ReportQpsScenarioService. + */ public static final class ReportQpsScenarioServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ReportQpsScenarioServiceBlockingStub( diff --git a/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/WorkerServiceGrpc.java b/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/WorkerServiceGrpc.java index 0ee6797c8e3..721b4f9ab19 100644 --- a/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/WorkerServiceGrpc.java +++ b/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/WorkerServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/services.proto") @io.grpc.stub.annotations.GrpcGenerated public final class WorkerServiceGrpc { @@ -153,6 +150,21 @@ public WorkerServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions ca return WorkerServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static WorkerServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public WorkerServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new WorkerServiceBlockingV2Stub(channel, callOptions); + } + }; + return WorkerServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ 
-323,6 +335,77 @@ public void quitWorker(io.grpc.benchmarks.proto.Control.Void request, /** * A stub to allow clients to do synchronous rpc calls to service WorkerService. */ + public static final class WorkerServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private WorkerServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected WorkerServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new WorkerServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Start server with specified workload.
+     * First request sent specifies the ServerConfig followed by ServerStatus
+     * response. After that, a "Mark" can be sent anytime to request the latest
+     * stats. Closing the stream will initiate shutdown of the test server
+     * and once the shutdown has finished, the OK status is sent to terminate
+     * this RPC.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + runServer() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getRunServerMethod(), getCallOptions()); + } + + /** + *
+     * Start client with specified workload.
+     * First request sent specifies the ClientConfig followed by ClientStatus
+     * response. After that, a "Mark" can be sent anytime to request the latest
+     * stats. Closing the stream will initiate shutdown of the test client
+     * and once the shutdown has finished, the OK status is sent to terminate
+     * this RPC.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + runClient() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getRunClientMethod(), getCallOptions()); + } + + /** + *
+     * Just return the core count - unary call
+     * 
+ */ + public io.grpc.benchmarks.proto.Control.CoreResponse coreCount(io.grpc.benchmarks.proto.Control.CoreRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getCoreCountMethod(), getCallOptions(), request); + } + + /** + *
+     * Quit this worker
+     * 
+ */ + public io.grpc.benchmarks.proto.Control.Void quitWorker(io.grpc.benchmarks.proto.Control.Void request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getQuitWorkerMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service WorkerService. + */ public static final class WorkerServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private WorkerServiceBlockingStub( diff --git a/binder/build.gradle b/binder/build.gradle index e18361e08b3..0da3f97ceee 100644 --- a/binder/build.gradle +++ b/binder/build.gradle @@ -6,33 +6,31 @@ plugins { description = 'gRPC BinderChannel' android { - namespace 'io.grpc.binder' + namespace = 'io.grpc.binder' compileSdkVersion 34 compileOptions { sourceCompatibility 1.8 targetCompatibility 1.8 } defaultConfig { - minSdkVersion 21 + minSdkVersion 23 targetSdkVersion 33 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" - multiDexEnabled true } - lintOptions { abortOnError false } + lintOptions { abortOnError = false } publishing { singleVariant('release') { withSourcesJar() withJavadocJar() } } - testFixtures { enable true } + testFixtures { enable = true } } repositories { google() - mavenCentral() } dependencies { @@ -73,6 +71,7 @@ dependencies { androidTestImplementation testFixtures(project(':grpc-core')) testFixturesImplementation libraries.guava.testlib + testFixturesImplementation testFixtures(project(':grpc-core')) } import net.ltgt.gradle.errorprone.CheckSeverity diff --git a/binder/src/androidTest/AndroidManifest.xml b/binder/src/androidTest/AndroidManifest.xml index b6d71574410..44f21e104d9 100644 --- a/binder/src/androidTest/AndroidManifest.xml +++ b/binder/src/androidTest/AndroidManifest.xml @@ -11,11 +11,13 @@ + + diff --git a/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java 
b/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java index 79f7b98f045..4e3cfcf0d05 100644 --- a/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java @@ -23,6 +23,7 @@ import android.content.Context; import android.content.Intent; +import android.net.Uri; import android.os.Parcel; import android.os.Parcelable; import androidx.test.core.app.ApplicationProvider; @@ -39,7 +40,6 @@ import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; -import io.grpc.NameResolverRegistry; import io.grpc.ServerCall; import io.grpc.ServerCall.Listener; import io.grpc.ServerCallHandler; @@ -49,7 +49,6 @@ import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; import io.grpc.internal.GrpcUtil; -import io.grpc.internal.testing.FakeNameResolverProvider; import io.grpc.stub.ClientCalls; import io.grpc.stub.MetadataUtils; import io.grpc.stub.ServerCalls; @@ -77,7 +76,6 @@ public final class BinderChannelSmokeTest { private static final int SLIGHTLY_MORE_THAN_ONE_BLOCK = 16 * 1024 + 100; private static final String MSG = "Some text which will be repeated many many times"; - private static final String SERVER_TARGET_URI = "fake://server"; private static final Metadata.Key POISON_KEY = ParcelableUtils.metadataKey("poison-bin", PoisonParcelable.CREATOR); @@ -99,7 +97,7 @@ public final class BinderChannelSmokeTest { .setType(MethodDescriptor.MethodType.BIDI_STREAMING) .build(); - FakeNameResolverProvider fakeNameResolverProvider; + AndroidComponentAddress serverAddress; ManagedChannel channel; AtomicReference headersCapture = new AtomicReference<>(); AtomicReference clientUidCapture = new AtomicReference<>(); @@ -137,9 +135,7 @@ public void setUp() throws Exception { TestUtils.recordRequestHeadersInterceptor(headersCapture), PeerUids.newPeerIdentifyingServerInterceptor()); - AndroidComponentAddress serverAddress = 
HostServices.allocateService(appContext); - fakeNameResolverProvider = new FakeNameResolverProvider(SERVER_TARGET_URI, serverAddress); - NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverProvider); + serverAddress = HostServices.allocateService(appContext); HostServices.configureService( serverAddress, HostServices.serviceParamsBuilder() @@ -154,19 +150,20 @@ public void setUp() throws Exception { .build()) .build()); - channel = - BinderChannelBuilder.forAddress(serverAddress, appContext) + channel = newBinderChannelBuilder().build(); + } + + BinderChannelBuilder newBinderChannelBuilder() { + return BinderChannelBuilder.forAddress(serverAddress, appContext) .inboundParcelablePolicy( - InboundParcelablePolicy.newBuilder() - .setAcceptParcelableMetadataValues(true) - .build()) - .build(); + InboundParcelablePolicy.newBuilder() + .setAcceptParcelableMetadataValues(true) + .build()); } @After public void tearDown() throws Exception { channel.shutdownNow(); - NameResolverRegistry.getDefaultRegistry().deregister(fakeNameResolverProvider); HostServices.awaitServiceShutdown(); } @@ -191,6 +188,18 @@ public void testBasicCall() throws Exception { assertThat(doCall("Hello").get()).isEqualTo("Hello"); } + @Test + public void testBasicCallWithLegacyAuthStrategy() throws Exception { + channel = newBinderChannelBuilder().useLegacyAuthStrategy().build(); + assertThat(doCall("Hello").get()).isEqualTo("Hello"); + } + + @Test + public void testBasicCallWithV2AuthStrategy() throws Exception { + channel = newBinderChannelBuilder().useV2AuthStrategy().build(); + assertThat(doCall("Hello").get()).isEqualTo("Hello"); + } + @Test public void testPeerUidIsRecorded() throws Exception { assertThat(doCall("Hello").get()).isEqualTo("Hello"); @@ -235,7 +244,11 @@ public void testStreamingCallOptionHeaders() throws Exception { @Test public void testConnectViaTargetUri() throws Exception { - channel = BinderChannelBuilder.forTarget(SERVER_TARGET_URI, appContext).build(); + // 
Compare with the mapping in AndroidManifest.xml. + channel = + BinderChannelBuilder.forTarget( + "intent://authority/path#Intent;action=action1;scheme=scheme;end;", appContext) + .build(); assertThat(doCall("Hello").get()).isEqualTo("Hello"); } @@ -245,7 +258,10 @@ public void testConnectViaIntentFilter() throws Exception { channel = BinderChannelBuilder.forAddress( AndroidComponentAddress.forBindIntent( - new Intent().setAction("action1").setPackage(appContext.getPackageName())), + new Intent() + .setAction("action1") + .setData(Uri.parse("scheme://authority/path")) + .setPackage(appContext.getPackageName())), appContext) .build(); assertThat(doCall("Hello").get()).isEqualTo("Hello"); diff --git a/binder/src/androidTest/java/io/grpc/binder/HostServices.java b/binder/src/androidTest/java/io/grpc/binder/HostServices.java index 4aa46e8254a..5d4a06a27fe 100644 --- a/binder/src/androidTest/java/io/grpc/binder/HostServices.java +++ b/binder/src/androidTest/java/io/grpc/binder/HostServices.java @@ -29,6 +29,7 @@ import androidx.lifecycle.LifecycleService; import com.google.auto.value.AutoValue; import com.google.common.base.Supplier; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Server; import java.io.IOException; import java.util.HashMap; @@ -38,7 +39,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * A test helper class for creating android services to host gRPC servers. 
diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java index c84a1fc296f..aa3fb573ab5 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java @@ -17,6 +17,7 @@ package io.grpc.binder.internal; import static com.google.common.truth.Truth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; import android.content.Context; import android.os.DeadObjectException; @@ -24,9 +25,9 @@ import android.os.RemoteException; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.protobuf.Empty; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; @@ -37,16 +38,18 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.binder.AndroidComponentAddress; -import io.grpc.binder.AsyncSecurityPolicy; import io.grpc.binder.BinderServerBuilder; import io.grpc.binder.HostServices; import io.grpc.binder.SecurityPolicy; +import io.grpc.binder.internal.FakeDeadBinder; import io.grpc.binder.internal.OneWayBinderProxies.BlackHoleOneWayBinderProxy; import io.grpc.binder.internal.OneWayBinderProxies.BlockingBinderDecorator; import io.grpc.binder.internal.OneWayBinderProxies.ThrowingOneWayBinderProxy; +import io.grpc.binder.internal.SettableAsyncSecurityPolicy.AuthRequest; import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; +import 
io.grpc.internal.DisconnectError; import io.grpc.internal.FixedObjectPool; import io.grpc.internal.ManagedClientTransport; import io.grpc.internal.ObjectPool; @@ -62,9 +65,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -101,7 +102,7 @@ public final class BinderClientTransportTest { .build(); AndroidComponentAddress serverAddress; - BinderTransport.BinderClientTransport transport; + BinderClientTransport transport; BlockingSecurityPolicy blockingSecurityPolicy = new BlockingSecurityPolicy(); private final ObjectPool executorServicePool = @@ -154,23 +155,32 @@ private class BinderClientTransportBuilder { .setScheduledExecutorPool(executorServicePool) .setOffloadExecutorPool(offloadServicePool); + @CanIgnoreReturnValue public BinderClientTransportBuilder setSecurityPolicy(SecurityPolicy securityPolicy) { factoryBuilder.setSecurityPolicy(securityPolicy); return this; } + @CanIgnoreReturnValue public BinderClientTransportBuilder setBinderDecorator( OneWayBinderProxy.Decorator binderDecorator) { factoryBuilder.setBinderDecorator(binderDecorator); return this; } + @CanIgnoreReturnValue public BinderClientTransportBuilder setReadyTimeoutMillis(int timeoutMillis) { factoryBuilder.setReadyTimeoutMillis(timeoutMillis); return this; } - public BinderTransport.BinderClientTransport build() { + @CanIgnoreReturnValue + public BinderClientTransportBuilder setPreAuthorizeServer(boolean preAuthorizeServer) { + factoryBuilder.setPreAuthorizeServers(preAuthorizeServer); + return this; + } + + public BinderClientTransport build() { return factoryBuilder .buildClientTransportFactory() .newClientTransport(serverAddress, new ClientTransportOptions(), null); @@ -189,7 +199,7 @@ public void tearDown() 
throws Exception { private static void shutdownAndTerminate(ExecutorService executorService) throws InterruptedException { executorService.shutdownNow(); - if (!executorService.awaitTermination(TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + if (!executorService.awaitTermination(TIMEOUT_SECONDS, SECONDS)) { throw new AssertionError("executor failed to terminate promptly"); } } @@ -350,6 +360,20 @@ public void testTxnFailurePostSetup() throws Exception { assertThat(streamStatus.getCause()).isSameInstanceAs(doe); } + @Test + public void testServerBinderDeadOnArrival() throws Exception { + BlockingBinderDecorator decorator = new BlockingBinderDecorator<>(); + transport = new BinderClientTransportBuilder().setBinderDecorator(decorator).build(); + transport.start(transportListener).run(); + decorator.putNextResult(decorator.takeNextRequest()); // Server's "Endpoint" Binder. + OneWayBinderProxy unusedServerBinder = decorator.takeNextRequest(); + decorator.putNextResult( + OneWayBinderProxy.wrap(new FakeDeadBinder(), offloadServicePool.getObject())); + Status clientStatus = transportListener.awaitShutdown(); + assertThat(clientStatus.getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(clientStatus.getDescription()).contains("Failed to observe outgoing binder"); + } + @Test public void testBlackHoleEndpointConnectTimeout() throws Exception { BlockingBinderDecorator decorator = new BlockingBinderDecorator<>(); @@ -370,27 +394,58 @@ public void testBlackHoleEndpointConnectTimeout() throws Exception { } @Test - public void testBlackHoleSecurityPolicyConnectTimeout() throws Exception { + public void testBlackHoleSecurityPolicyAuthTimeout() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); transport = new BinderClientTransportBuilder() - .setSecurityPolicy(blockingSecurityPolicy) + .setSecurityPolicy(securityPolicy) + .setPreAuthorizeServer(false) .setReadyTimeoutMillis(1_234) .build(); transport.start(transportListener).run(); + // 
Take the next authRequest but don't respond to it, in order to trigger the ready timeout. + AuthRequest authRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS); + Status transportStatus = transportListener.awaitShutdown(); assertThat(transportStatus.getCode()).isEqualTo(Code.DEADLINE_EXCEEDED); assertThat(transportStatus.getDescription()).contains("1234"); transportListener.awaitTermination(); - blockingSecurityPolicy.provideNextCheckAuthorizationResult(Status.OK); + // If the transport gave up waiting on auth, it should cancel its request. + assertThat(authRequest.isCancelled()).isTrue(); } @Test - public void testAsyncSecurityPolicyFailure() throws Exception { + public void testBlackHoleSecurityPolicyPreAuthTimeout() throws Exception { SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); - transport = new BinderClientTransportBuilder().setSecurityPolicy(securityPolicy).build(); + transport = + new BinderClientTransportBuilder() + .setSecurityPolicy(securityPolicy) + .setPreAuthorizeServer(true) + .setReadyTimeoutMillis(1_234) + .build(); + transport.start(transportListener).run(); + // Take the next authRequest but don't respond to it, in order to trigger the ready timeout. + AuthRequest preAuthRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS); + + Status transportStatus = transportListener.awaitShutdown(); + assertThat(transportStatus.getCode()).isEqualTo(Code.DEADLINE_EXCEEDED); + assertThat(transportStatus.getDescription()).contains("1234"); + transportListener.awaitTermination(); + // If the transport gave up waiting on auth, it should cancel its request. 
+ assertThat(preAuthRequest.isCancelled()).isTrue(); + } + + @Test + public void testAsyncSecurityPolicyAuthFailure() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = + new BinderClientTransportBuilder() + .setPreAuthorizeServer(false) + .setSecurityPolicy(securityPolicy) + .build(); RuntimeException exception = new NullPointerException(); - securityPolicy.setAuthorizationException(exception); transport.start(transportListener).run(); + securityPolicy.takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS).setResult(exception); Status transportStatus = transportListener.awaitShutdown(); assertThat(transportStatus.getCode()).isEqualTo(Code.INTERNAL); assertThat(transportStatus.getCause()).isEqualTo(exception); @@ -398,19 +453,72 @@ public void testAsyncSecurityPolicyFailure() throws Exception { } @Test - public void testAsyncSecurityPolicySuccess() throws Exception { + public void testAsyncSecurityPolicyPreAuthFailure() throws Exception { SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); - transport = new BinderClientTransportBuilder().setSecurityPolicy(securityPolicy).build(); - securityPolicy.setAuthorizationResult(Status.PERMISSION_DENIED); + transport = + new BinderClientTransportBuilder() + .setPreAuthorizeServer(true) + .setSecurityPolicy(securityPolicy) + .build(); + RuntimeException exception = new NullPointerException(); + transport.start(transportListener).run(); + securityPolicy.takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS).setResult(exception); + Status transportStatus = transportListener.awaitShutdown(); + assertThat(transportStatus.getCode()).isEqualTo(Code.INTERNAL); + assertThat(transportStatus.getCause()).isEqualTo(exception); + transportListener.awaitTermination(); + } + + @Test + public void testAsyncSecurityPolicyAuthSuccess() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = + new 
BinderClientTransportBuilder() + .setPreAuthorizeServer(false) + .setSecurityPolicy(securityPolicy) + .build(); transport.start(transportListener).run(); + securityPolicy + .takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS) + .setResult(Status.PERMISSION_DENIED.withDescription("xyzzy")); Status transportStatus = transportListener.awaitShutdown(); assertThat(transportStatus.getCode()).isEqualTo(Code.PERMISSION_DENIED); + assertThat(transportStatus.getDescription()).contains("xyzzy"); transportListener.awaitTermination(); } + @Test + public void testAsyncSecurityPolicyPreAuthSuccess() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = + new BinderClientTransportBuilder() + .setPreAuthorizeServer(true) + .setSecurityPolicy(securityPolicy) + .build(); + transport.start(transportListener).run(); + securityPolicy + .takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS) + .setResult(Status.PERMISSION_DENIED.withDescription("xyzzy")); + Status transportStatus = transportListener.awaitShutdown(); + assertThat(transportStatus.getCode()).isEqualTo(Code.PERMISSION_DENIED); + assertThat(transportStatus.getDescription()).contains("xyzzy"); + transportListener.awaitTermination(); + } + + @Test + public void testAsyncSecurityPolicyCancelledUponExternalTermination() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = new BinderClientTransportBuilder().setSecurityPolicy(securityPolicy).build(); + transport.start(transportListener).run(); + AuthRequest authRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS); + transport.shutdownNow(Status.UNAVAILABLE); // 'authRequest' remains unanswered! 
+ transportListener.awaitShutdown(); + transportListener.awaitTermination(); + assertThat(authRequest.isCancelled()).isTrue(); + } + private static void startAndAwaitReady( - BinderTransport.BinderClientTransport transport, TestTransportListener transportListener) - throws Exception { + BinderClientTransport transport, TestTransportListener transportListener) throws Exception { transport.start(transportListener).run(); transportListener.awaitReady(); } @@ -422,14 +530,14 @@ private static final class TestTransportListener implements ManagedClientTranspo private final SettableFuture isTerminated = SettableFuture.create(); @Override - public void transportShutdown(Status shutdownStatus) { + public void transportShutdown(Status shutdownStatus, DisconnectError disconnectError) { if (!this.shutdownStatus.set(shutdownStatus)) { throw new IllegalStateException("transportShutdown() already called"); } } public Status awaitShutdown() throws Exception { - return shutdownStatus.get(TIMEOUT_SECONDS, TimeUnit.SECONDS); + return shutdownStatus.get(TIMEOUT_SECONDS, SECONDS); } @Override @@ -440,7 +548,7 @@ public void transportTerminated() { } public void awaitTermination() throws Exception { - isTerminated.get(TIMEOUT_SECONDS, TimeUnit.SECONDS); + isTerminated.get(TIMEOUT_SECONDS, SECONDS); } @Override @@ -451,7 +559,7 @@ public void transportReady() { } public void awaitReady() throws Exception { - isReady.get(TIMEOUT_SECONDS, TimeUnit.SECONDS); + isReady.get(TIMEOUT_SECONDS, SECONDS); } @Override @@ -567,25 +675,4 @@ public Status checkAuthorization(int uid) { } } } - - /** An AsyncSecurityPolicy that lets a test specify the outcome of checkAuthorizationAsync(). 
*/ - static class SettableAsyncSecurityPolicy extends AsyncSecurityPolicy { - private SettableFuture result = SettableFuture.create(); - - public void clearAuthorizationResult() { - result = SettableFuture.create(); - } - - public boolean setAuthorizationResult(Status status) { - return result.set(status); - } - - public boolean setAuthorizationException(Throwable t) { - return result.setException(t); - } - - public ListenableFuture checkAuthorizationAsync(int uid) { - return Futures.nonCancellationPropagating(result); - } - } } diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java index fc9a383d572..7932cabde89 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java @@ -106,8 +106,7 @@ protected ManagedClientTransport newClientTransport(InternalServer server) { options.setEagAttributes(eagAttrs()); options.setChannelLogger(transportLogger()); - return new BinderTransport.BinderClientTransport( - builder.buildClientTransportFactory(), addr, options); + return new BinderClientTransport(builder.buildClientTransportFactory(), addr, options); } @Test diff --git a/binder/src/main/AndroidManifest.xml b/binder/src/main/AndroidManifest.xml index a30cbbdd6fa..239c3b39b38 100644 --- a/binder/src/main/AndroidManifest.xml +++ b/binder/src/main/AndroidManifest.xml @@ -1,2 +1,11 @@ - - + + + + + + + + + + + \ No newline at end of file diff --git a/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java b/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java index 6c1026e2127..b390c1f0ccd 100644 --- a/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java +++ b/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java @@ -18,10 +18,14 @@ import static android.content.Intent.URI_ANDROID_APP_SCHEME; import static 
com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import android.content.ComponentName; import android.content.Context; import android.content.Intent; +import android.os.UserHandle; +import com.google.common.base.Objects; +import io.grpc.ExperimentalApi; import java.net.SocketAddress; import javax.annotation.Nullable; @@ -41,18 +45,25 @@ * fields, namely, an action of {@link ApiConstants#ACTION_BIND}, an empty category set and null * type and data URI. * - *

The semantics of {@link #equals(Object)} are the same as {@link Intent#filterEquals(Intent)}. + *

Optionally contains a {@link UserHandle} that must be considered wherever the {@link Intent} + * is evaluated. + * + *

{@link #equals(Object)} uses {@link Intent#filterEquals(Intent)} semantics to compare Intents. */ public final class AndroidComponentAddress extends SocketAddress { private static final long serialVersionUID = 0L; private final Intent bindIntent; // "Explicit", having either a component or package restriction. - protected AndroidComponentAddress(Intent bindIntent) { + @Nullable + private final UserHandle targetUser; // null means the same user that hosts this process. + + private AndroidComponentAddress(Intent bindIntent, @Nullable UserHandle targetUser) { checkArgument( bindIntent.getComponent() != null || bindIntent.getPackage() != null, "'bindIntent' must be explicit. Specify either a package or ComponentName."); this.bindIntent = bindIntent; + this.targetUser = targetUser; } /** @@ -99,7 +110,7 @@ public static AndroidComponentAddress forRemoteComponent( * @throws IllegalArgumentException if 'intent' isn't "explicit" */ public static AndroidComponentAddress forBindIntent(Intent intent) { - return new AndroidComponentAddress(intent.cloneFilter()); + return new AndroidComponentAddress(intent.cloneFilter(), null); } /** @@ -108,7 +119,7 @@ public static AndroidComponentAddress forBindIntent(Intent intent) { */ public static AndroidComponentAddress forComponent(ComponentName component) { return new AndroidComponentAddress( - new Intent(ApiConstants.ACTION_BIND).setComponent(component)); + new Intent(ApiConstants.ACTION_BIND).setComponent(component), null); } /** @@ -141,6 +152,9 @@ public ComponentName getComponent() { /** * Returns this address as an explicit {@link Intent} suitable for passing to {@link * Context#bindService}. + * + *

NB: The returned Intent does not specify a target Android user. If {@link #getTargetUser()} + * is non-null, {@link Context#bindServiceAsUser} should be called instead. */ public Intent asBindIntent() { return bindIntent.cloneFilter(); // Intent is mutable so return a copy. @@ -177,13 +191,92 @@ public int hashCode() { public boolean equals(Object obj) { if (obj instanceof AndroidComponentAddress) { AndroidComponentAddress that = (AndroidComponentAddress) obj; - return bindIntent.filterEquals(that.bindIntent); + return bindIntent.filterEquals(that.bindIntent) + && Objects.equal(this.targetUser, that.targetUser); } return false; } @Override public String toString() { - return "AndroidComponentAddress[" + bindIntent + "]"; + StringBuilder builder = new StringBuilder("AndroidComponentAddress["); + if (targetUser != null) { + builder.append(targetUser); + builder.append("@"); + } + builder.append(bindIntent); + builder.append("]"); + return builder.toString(); + } + + /** + * Identifies the Android user in which the bind Intent will be evaluated. + * + *

Returns the {@link UserHandle}, or null which means that the Android user hosting the + * current process will be used. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10173") + @Nullable + public UserHandle getTargetUser() { + return targetUser; + } + + public static Builder newBuilder() { + return new Builder(); + } + + /** Fluently builds instances of {@link AndroidComponentAddress}. */ + public static class Builder { + Intent bindIntent; + UserHandle targetUser; + + /** + * Sets the binding {@link Intent} to one having the "filter matching" fields of 'intent'. + * + *

'intent' must be "explicit", i.e. having either a target component ({@link + * Intent#getComponent()}) or package restriction ({@link Intent#getPackage()}). + */ + public Builder setBindIntent(Intent intent) { + this.bindIntent = intent.cloneFilter(); + return this; + } + + /** + * Sets the binding {@link Intent} to one with the specified 'component' and default values for + * all other fields, for convenience. + */ + public Builder setBindIntentFromComponent(ComponentName component) { + this.bindIntent = new Intent(ApiConstants.ACTION_BIND).setComponent(component); + return this; + } + + /** + * Specifies the Android user in which the built Address' bind Intent will be evaluated. + * + *

Connecting to a server in a different Android user is uncommon and requires the client app + * have runtime visibility of @SystemApi's and hold certain @SystemApi permissions. + * The device must also be running Android SDK version 30 or higher. + * + *

See https://developer.android.com/guide/app-compatibility/restrictions-non-sdk-interfaces + * for details on which apps can call the underlying @SystemApi's needed to make this type + * of connection. + * + *

One of the "android.permission.INTERACT_ACROSS_XXX" permissions is required. The exact one + * depends on the calling user's relationship to the target user, whether client and server are + * in the same or different apps, and the version of Android in use. See {@link + * Context#bindServiceAsUser}, the essential underlying Android API, for details. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10173") + public Builder setTargetUser(@Nullable UserHandle targetUser) { + this.targetUser = targetUser; + return this; + } + + public AndroidComponentAddress build() { + // We clone any incoming mutable intent in the setter, not here. AndroidComponentAddress + // itself is immutable so multiple instances built from here can safely share 'bindIntent'. + checkState(bindIntent != null, "Required property 'bindIntent' unset"); + return new AndroidComponentAddress(bindIntent, targetUser); + } } } diff --git a/binder/src/main/java/io/grpc/binder/ApiConstants.java b/binder/src/main/java/io/grpc/binder/ApiConstants.java index 43e94338fdc..fbf4be6b7ce 100644 --- a/binder/src/main/java/io/grpc/binder/ApiConstants.java +++ b/binder/src/main/java/io/grpc/binder/ApiConstants.java @@ -17,7 +17,11 @@ package io.grpc.binder; import android.content.Intent; +import android.os.UserHandle; +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; +import io.grpc.NameResolver; /** Constant parts of the gRPC binder transport public API. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") @@ -29,4 +33,43 @@ private ApiConstants() {} * themselves in a {@link android.app.Service#onBind(Intent)} call. */ public static final String ACTION_BIND = "grpc.io.action.BIND"; + + /** + * Gives a {@link NameResolver} access to its Channel's "source" {@link android.content.Context}, + * the entry point to almost every other Android API. + * + *

This argument is set automatically by {@link BinderChannelBuilder}. Any value passed to + * {@link io.grpc.ManagedChannelBuilder#setNameResolverArg} will be ignored. + * + *

See {@link BinderChannelBuilder#forTarget(String, android.content.Context)} for more. + */ + public static final NameResolver.Args.Key SOURCE_ANDROID_CONTEXT = + NameResolver.Args.Key.create("source-android-context"); + + /** + * Specifies the Android user in which target URIs should be resolved. + * + *

{@link UserHandle} can't reasonably be encoded in a target URI string. Instead, all {@link + * io.grpc.NameResolverProvider}s producing {@link AndroidComponentAddress}es should let clients + * address servers in another Android user using this argument. + * + *

Connecting to a server in a different Android user is uncommon and can only be done by a + * "system app" client with special permissions. See {@link + * AndroidComponentAddress.Builder#setTargetUser(UserHandle)} for details. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10173") + public static final NameResolver.Args.Key TARGET_ANDROID_USER = + NameResolver.Args.Key.create("target-android-user"); + + /** + * Lets you override a Channel's pre-auth configuration (see {@link + * BinderChannelBuilder#preAuthorizeServers(boolean)}) for a given {@link EquivalentAddressGroup}. + * + *

A {@link NameResolver} that discovers servers from an untrusted source like PackageManager + * can use this to force server pre-auth and prevent abuse. + */ + @EquivalentAddressGroup.Attr + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12191") + public static final Attributes.Key PRE_AUTH_SERVER_OVERRIDE = + Attributes.Key.create("pre-auth-server-override"); } diff --git a/binder/src/main/java/io/grpc/binder/AsyncSecurityPolicy.java b/binder/src/main/java/io/grpc/binder/AsyncSecurityPolicy.java index 2a37e6fd517..9594c644e0c 100644 --- a/binder/src/main/java/io/grpc/binder/AsyncSecurityPolicy.java +++ b/binder/src/main/java/io/grpc/binder/AsyncSecurityPolicy.java @@ -17,11 +17,11 @@ package io.grpc.binder; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.ExperimentalApi; import io.grpc.Status; import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; -import javax.annotation.CheckReturnValue; /** * Decides whether a given Android UID is authorized to access some resource. @@ -67,4 +67,25 @@ public final Status checkAuthorization(int uid) { * authorized. */ public abstract ListenableFuture checkAuthorizationAsync(int uid); + + /** + * Decides whether the given Android UID is authorized, without providing its raw integer value. + * + *

Calling this is equivalent to calling {@link SecurityPolicy#checkAuthorization(int)}, except + * the caller provides a {@link PeerUid} wrapper instead of the raw integer uid (known only to the + * transport). This allows a server to check additional application-layer security policy for + * itself *after* the call itself is authorized by the transport layer. Cross cutting application- + * layer checks could be done from a {@link io.grpc.ServerInterceptor}. Checks based on the + * substance of a request message could be done by the individual RPC method implementations + * themselves. + * + *

See #checkAuthorizationAsync(int) for details on the semantics. See {@link + * PeerUids#newPeerIdentifyingServerInterceptor()} for how to get a {@link PeerUid}. + * + * @param uid The Android UID to authenticate. + * @return A gRPC {@link Status} object, with OK indicating authorized. + */ + public final ListenableFuture checkAuthorizationAsync(PeerUid uid) { + return checkAuthorizationAsync(uid.getUid()); + } } diff --git a/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java b/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java index d054c8d8ba6..a241634dd22 100644 --- a/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java +++ b/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java @@ -20,8 +20,6 @@ import static com.google.common.base.Preconditions.checkState; import android.content.Context; -import android.os.UserHandle; -import androidx.annotation.RequiresApi; import com.google.errorprone.annotations.DoNotCall; import io.grpc.ExperimentalApi; import io.grpc.ForwardingChannelBuilder; @@ -235,24 +233,6 @@ public BinderChannelBuilder securityPolicy(SecurityPolicy securityPolicy) { return this; } - /** - * Provides the target {@UserHandle} of the remote Android service. - * - *

When targetUserHandle is set, Context.bindServiceAsUser will used and additional Android - * permissions will be required. If your usage does not require cross-user communications, please - * do not set this field. It is the caller's responsibility to make sure that it holds the - * corresponding permissions. - * - * @param targetUserHandle the target user to bind into. - * @return this - */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10173") - @RequiresApi(30) - public BinderChannelBuilder bindAsUser(UserHandle targetUserHandle) { - transportFactoryBuilder.setTargetUserHandle(targetUserHandle); - return this; - } - /** Sets the policy for inbound parcelable objects. */ public BinderChannelBuilder inboundParcelablePolicy( InboundParcelablePolicy inboundParcelablePolicy) { @@ -271,6 +251,96 @@ public BinderChannelBuilder strictLifecycleManagement() { return this; } + /** + * Checks servers against this Channel's {@link SecurityPolicy} *before* binding. + * + *

Android users can be tricked into installing a malicious app with the same package name as a + * legitimate server. That's why we don't send calls to a server until it has been authorized by + * an appropriate {@link SecurityPolicy}. But merely binding to a malicious server can enable + * "keep-alive" and "background activity launch" abuse, even if it's ultimately unauthorized. + * Pre-authorization mitigates these threats by performing a preliminary {@link SecurityPolicy} + * check against a server app's PackageManager-registered identity without actually creating an + * instance of it. This is especially important for security when the server's direct address + * isn't known in advance but rather resolved via target URI or discovered by other means. + * + *

Note that, unlike ordinary authorization, pre-authorization is performed against the server + * app's UID, not the UID of the process hosting the bound Service. These can be different, most + * commonly due to services that set `android:isolatedProcess=true`. + * + *

Pre-authorization is strongly recommended but it remains optional for now because of this + * behavior change and the small performance cost. + * + *

The default value of this property is false but it will become true in a future release. + * Clients that require a particular behavior should configure it explicitly using this method + * rather than relying on the default. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12191") + public BinderChannelBuilder preAuthorizeServers(boolean preAuthorize) { + transportFactoryBuilder.setPreAuthorizeServers(preAuthorize); + return this; + } + + /** + * Specifies how and when to authorize a server against this Channel's {@link SecurityPolicy}. + * + *

This method selects the original "legacy" authorization strategy, which is no longer + * preferred for two reasons: First, the legacy strategy considers the UID of the server *process* + * we connect to. This is problematic for services using the `android:isolatedProcess` attribute, + * which runs them under a different "ephemeral" UID. This UID lacks all the privileges of the + * hosting app -- any non-trivial SecurityPolicy would fail to authorize it. Second, the legacy + * authorization strategy performs SecurityPolicy checks later in the connection handshake, which + * means the calling UID must be rechecked on every subsequent RPC. For these reasons, prefer + * {@link #useV2AuthStrategy} instead. + * + *

The server does not know which authorization strategy a client is using. Both strategies + * work with all versions of the grpc-binder server. + * + *

Callers need not specify an authorization strategy, but the default is unspecified and will + * eventually become {@link #useV2AuthStrategy()}. Clients that require the legacy strategy should + * configure it explicitly using this method. Eventually, however, legacy support will be + * deprecated and removed. + * + * @return this + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12397") + public BinderChannelBuilder useLegacyAuthStrategy() { + transportFactoryBuilder.setUseLegacyAuthStrategy(true); + return this; + } + + /** + * Specifies how and when to authorize a server against this Channel's {@link SecurityPolicy}. + * + *

This method selects the v2 authorization strategy. It improves on the original strategy + * ({@link #useLegacyAuthStrategy}), by considering the UID of the server *app* we connect to, + * rather than the server *process*. This allows clients to connect to services configured with + * the `android:isolatedProcess` attribute, which run with the same authority as the hosting app, + * but under a different "ephemeral" UID that any non-trivial SecurityPolicy would fail to + * authorize. + * + *

Furthermore, the v2 authorization strategy performs SecurityPolicy checks earlier in the + * connection handshake, which allows subsequent RPCs over that connection to proceed securely + * without further UID checks. For these reasons, clients should prefer the v2 strategy. + * + *

The server does not know which authorization strategy a client is using. Both strategies + * work with all versions of the grpc-binder server. + * + *

Callers need not specify an authorization strategy, but the default is unspecified and can + * change over time. Clients that require the v2 strategy should configure it explicitly using + * this method. Eventually, this strategy will become the default and legacy support will be + * removed. + * + *

If moving to the new authorization strategy causes a robolectric test to fail, ensure your + * fake Service component is registered with `ShadowPackageManager` using `addOrUpdateService()`. + * + * @return this + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12397") + public BinderChannelBuilder useV2AuthStrategy() { + transportFactoryBuilder.setUseLegacyAuthStrategy(false); + return this; + } + @Override public BinderChannelBuilder idleTimeout(long value, TimeUnit unit) { checkState( @@ -284,6 +354,8 @@ public BinderChannelBuilder idleTimeout(long value, TimeUnit unit) { public ManagedChannel build() { transportFactoryBuilder.setOffloadExecutorPool( managedChannelImplBuilder.getOffloadExecutorPool()); + setNameResolverArg( + ApiConstants.SOURCE_ANDROID_CONTEXT, transportFactoryBuilder.getSourceContext()); return super.build(); } } diff --git a/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java b/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java index c926c853472..5f0885883a5 100644 --- a/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java +++ b/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java @@ -68,7 +68,7 @@ private BinderServerBuilder( serverImplBuilder = new ServerImplBuilder( - streamTracerFactories -> { + (streamTracerFactories, metricRecorder) -> { internalBuilder.setStreamTracerFactories(streamTracerFactories); BinderServer server = internalBuilder.build(); BinderInternal.setIBinder(binderReceiver, server.getHostBinder()); diff --git a/binder/src/main/java/io/grpc/binder/SecurityPolicies.java b/binder/src/main/java/io/grpc/binder/SecurityPolicies.java index 05e8c43da79..c0f6fe81989 100644 --- a/binder/src/main/java/io/grpc/binder/SecurityPolicies.java +++ b/binder/src/main/java/io/grpc/binder/SecurityPolicies.java @@ -184,7 +184,6 @@ public Status checkAuthorization(int uid) { * Creates {@link SecurityPolicy} which checks if the app is a device owner app. See {@link * DevicePolicyManager}. 
*/ - @RequiresApi(18) public static io.grpc.binder.SecurityPolicy isDeviceOwner(Context applicationContext) { DevicePolicyManager devicePolicyManager = (DevicePolicyManager) applicationContext.getSystemService(Context.DEVICE_POLICY_SERVICE); @@ -199,7 +198,6 @@ public static io.grpc.binder.SecurityPolicy isDeviceOwner(Context applicationCon * Creates {@link SecurityPolicy} which checks if the app is a profile owner app. See {@link * DevicePolicyManager}. */ - @RequiresApi(21) public static SecurityPolicy isProfileOwner(Context applicationContext) { DevicePolicyManager devicePolicyManager = (DevicePolicyManager) applicationContext.getSystemService(Context.DEVICE_POLICY_SERVICE); diff --git a/binder/src/main/java/io/grpc/binder/SecurityPolicy.java b/binder/src/main/java/io/grpc/binder/SecurityPolicy.java index e539f17e394..3ad8903407f 100644 --- a/binder/src/main/java/io/grpc/binder/SecurityPolicy.java +++ b/binder/src/main/java/io/grpc/binder/SecurityPolicy.java @@ -16,8 +16,8 @@ package io.grpc.binder; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Status; -import javax.annotation.CheckReturnValue; /** * Decides whether a given Android UID is authorized to access some resource. @@ -53,4 +53,25 @@ protected SecurityPolicy() {} * @return A gRPC {@link Status} object, with OK indicating authorized. */ public abstract Status checkAuthorization(int uid); + + /** + * Decides whether the given Android UID is authorized, without providing its raw integer value. + * + *

Calling this is equivalent to calling {@link SecurityPolicy#checkAuthorization(int)}, except + * the caller provides a {@link PeerUid} wrapper instead of the raw integer uid (known only to the + * transport). This allows a server to check additional application-layer security policy for + * itself *after* the call itself is authorized by the transport layer. Cross cutting application- + * layer checks could be done from a {@link io.grpc.ServerInterceptor}. Checks based on the + * substance of a request message could be done by the individual RPC method implementations + * themselves. + * + *

See #checkAuthorizationAsync(int) for details on the semantics. See {@link + * PeerUids#newPeerIdentifyingServerInterceptor()} for how to get a {@link PeerUid}. + * + * @param uid The Android UID to authenticate. + * @return A gRPC {@link Status} object, with OK indicating authorized. + */ + public final Status checkAuthorization(PeerUid uid) { + return checkAuthorization(uid.getUid()); + } } diff --git a/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java b/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java index 6a9361c0eaf..4786a5e6cc4 100644 --- a/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java +++ b/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java @@ -19,10 +19,10 @@ import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Status; import java.util.HashMap; import java.util.Map; -import javax.annotation.CheckReturnValue; /** * A security policy for a gRPC server. diff --git a/binder/src/main/java/io/grpc/binder/UntrustedSecurityPolicies.java b/binder/src/main/java/io/grpc/binder/UntrustedSecurityPolicies.java index 64d8ac1426a..44612a82109 100644 --- a/binder/src/main/java/io/grpc/binder/UntrustedSecurityPolicies.java +++ b/binder/src/main/java/io/grpc/binder/UntrustedSecurityPolicies.java @@ -16,9 +16,9 @@ package io.grpc.binder; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.ExperimentalApi; import io.grpc.Status; -import javax.annotation.CheckReturnValue; /** Static factory methods for creating untrusted security policies. 
*/ @CheckReturnValue diff --git a/binder/src/main/java/io/grpc/binder/internal/ActiveTransportTracker.java b/binder/src/main/java/io/grpc/binder/internal/ActiveTransportTracker.java index ad410186486..01505bfd509 100644 --- a/binder/src/main/java/io/grpc/binder/internal/ActiveTransportTracker.java +++ b/binder/src/main/java/io/grpc/binder/internal/ActiveTransportTracker.java @@ -2,17 +2,17 @@ import static com.google.common.base.Preconditions.checkState; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.Metadata; import io.grpc.internal.ServerListener; import io.grpc.internal.ServerStream; import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransportListener; -import javax.annotation.concurrent.GuardedBy; /** - * Tracks which {@link BinderTransport.BinderServerTransport} are currently active and allows - * invoking a {@link Runnable} only once all transports are terminated. + * Tracks which {@link BinderServerTransport} are currently active and allows invoking a {@link + * Runnable} only once all transports are terminated. */ final class ActiveTransportTracker implements ServerListener { private final ServerListener delegate; diff --git a/binder/src/main/java/io/grpc/binder/internal/Bindable.java b/binder/src/main/java/io/grpc/binder/internal/Bindable.java index 8e1af64b63d..59a2502de2b 100644 --- a/binder/src/main/java/io/grpc/binder/internal/Bindable.java +++ b/binder/src/main/java/io/grpc/binder/internal/Bindable.java @@ -16,10 +16,12 @@ package io.grpc.binder.internal; +import android.content.pm.ServiceInfo; import android.os.IBinder; import androidx.annotation.AnyThread; import androidx.annotation.MainThread; import io.grpc.Status; +import io.grpc.StatusException; /** An interface for managing a {@code Binder} connection. 
*/ interface Bindable { @@ -45,6 +47,22 @@ interface Observer { void onUnbound(Status reason); } + /** + * Fetches details about the remote Service from PackageManager without binding to it. + * + *

Resolving an untrusted address before binding to it lets you screen out problematic servers + * before giving them a chance to run. However, note that the identity/existence of the resolved + * Service can change between the time this method returns and the time you actually bind/connect + * to it. For example, suppose the target package gets uninstalled or upgraded right after this + * method returns. + * + *

Compare with {@link #getConnectedServiceInfo()}, which can only be called after {@link + * Observer#onBound(IBinder)} but can be used to learn about the service you actually connected + * to. + */ + @AnyThread + ServiceInfo resolve() throws StatusException; + /** * Attempt to bind with the remote service. * @@ -53,6 +71,21 @@ interface Observer { @AnyThread void bind(); + /** + * Asks PackageManager for details about the remote Service we *actually* connected to. + * + *

Can only be called after {@link Observer#onBound}. + * + *

Compare with {@link #resolve()}, which reports which service would be selected as of now but + * *without* connecting. + * + * @throws StatusException UNIMPLEMENTED if the connected service isn't found (an {@link + * Observer#onUnbound} callback has likely already happened or is on its way!) + * @throws IllegalStateException if {@link Observer#onBound} has not "happened-before" this call + */ + @AnyThread + ServiceInfo getConnectedServiceInfo() throws StatusException; + /** * Unbind from the remote service if connected. * diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java new file mode 100644 index 00000000000..bef1eefd43e --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java @@ -0,0 +1,547 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.binder.ApiConstants.PRE_AUTH_SERVER_OVERRIDE; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import android.content.Context; +import android.content.pm.ServiceInfo; +import android.os.Binder; +import android.os.IBinder; +import android.os.Parcel; +import android.os.Process; +import androidx.annotation.BinderThread; +import androidx.annotation.MainThread; +import com.google.common.base.Ticker; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.CheckReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; +import io.grpc.Grpc; +import io.grpc.Internal; +import io.grpc.InternalLogId; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.SecurityLevel; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.AsyncSecurityPolicy; +import io.grpc.binder.InboundParcelablePolicy; +import io.grpc.binder.SecurityPolicy; +import io.grpc.internal.ClientStream; +import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; +import io.grpc.internal.ConnectionClientTransport; +import io.grpc.internal.FailingClientStream; +import io.grpc.internal.GrpcAttributes; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SimpleDisconnectError; +import io.grpc.internal.StatsTraceContext; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; +import 
javax.annotation.concurrent.ThreadSafe; + +/** Concrete client-side transport implementation. */ +@ThreadSafe +@Internal +public final class BinderClientTransport extends BinderTransport + implements ConnectionClientTransport, Bindable.Observer { + + private final ObjectPool offloadExecutorPool; + private final Executor offloadExecutor; + private final SecurityPolicy securityPolicy; + private final Bindable serviceBinding; + + @GuardedBy("this") + private final ClientHandshake handshake; + + /** Number of ongoing calls which keep this transport "in-use". */ + private final AtomicInteger numInUseStreams; + + private final long readyTimeoutMillis; + private final PingTracker pingTracker; + private final boolean preAuthorizeServer; + + @Nullable private ManagedClientTransport.Listener clientTransportListener; + + @GuardedBy("this") + private int latestCallId = FIRST_CALL_ID; + + @GuardedBy("this") + private ScheduledFuture readyTimeoutFuture; // != null iff timeout scheduled. + + /** + * Constructs a new transport instance. 
+ * + * @param factory parameters common to all a Channel's transports + * @param targetAddress the fully resolved and load-balanced server address + * @param options other parameters that can vary as transports come and go within a Channel + */ + public BinderClientTransport( + BinderClientTransportFactory factory, + AndroidComponentAddress targetAddress, + ClientTransportOptions options) { + super( + factory.scheduledExecutorPool, + buildClientAttributes( + options.getEagAttributes(), + factory.sourceContext, + targetAddress, + factory.inboundParcelablePolicy), + factory.binderDecorator, + buildLogId(factory.sourceContext, targetAddress)); + this.offloadExecutorPool = factory.offloadExecutorPool; + this.securityPolicy = factory.securityPolicy; + this.offloadExecutor = offloadExecutorPool.getObject(); + this.readyTimeoutMillis = factory.readyTimeoutMillis; + Boolean preAuthServerOverride = options.getEagAttributes().get(PRE_AUTH_SERVER_OVERRIDE); + this.preAuthorizeServer = + preAuthServerOverride != null ? preAuthServerOverride : factory.preAuthorizeServers; + this.handshake = + factory.useLegacyAuthStrategy ? 
new LegacyClientHandshake() : new V2ClientHandshake(); + numInUseStreams = new AtomicInteger(); + pingTracker = new PingTracker(Ticker.systemTicker(), (id) -> sendPing(id)); + serviceBinding = + new ServiceBinding( + factory.mainThreadExecutor, + factory.sourceContext, + factory.channelCredentials, + targetAddress.asBindIntent(), + targetAddress.getTargetUser(), + factory.bindServiceFlags.toInteger(), + this); + } + + @Override + void releaseExecutors() { + super.releaseExecutors(); + offloadExecutorPool.returnObject(offloadExecutor); + } + + @Override + public synchronized void onBound(IBinder binder) { + handshake.onBound(binderDecorator.decorate(OneWayBinderProxy.wrap(binder, offloadExecutor))); + } + + @Override + public synchronized void onUnbound(Status reason) { + shutdownInternal(reason, true); + } + + @CheckReturnValue + @Override + public synchronized Runnable start(Listener clientTransportListener) { + this.clientTransportListener = checkNotNull(clientTransportListener); + return this::postStartRunnable; + } + + private synchronized void postStartRunnable() { + if (!inState(TransportState.NOT_STARTED)) { + return; + } + + setState(TransportState.SETUP); + + try { + if (preAuthorizeServer) { + preAuthorize(serviceBinding.resolve()); + } else { + serviceBinding.bind(); + } + } catch (StatusException e) { + shutdownInternal(e.getStatus(), true); + return; + } + + if (readyTimeoutMillis >= 0) { + readyTimeoutFuture = + getScheduledExecutorService() + .schedule( + BinderClientTransport.this::onReadyTimeout, readyTimeoutMillis, MILLISECONDS); + } + } + + @GuardedBy("this") + private void preAuthorize(ServiceInfo serviceInfo) { + // It's unlikely, but the identity/existence of this Service could change by the time we + // actually connect. 
It doesn't matter though, because: + // - If pre-auth fails (but would succeed against the server's new state), the grpc-core layer + // will eventually retry using a new transport instance that will see the Service's new state. + // - If pre-auth succeeds (but would fail against the server's new state), we might give an + // unauthorized server a chance to run, but the connection will still fail by SecurityPolicy + // check later in handshake. Pre-auth remains effective at mitigating abuse because malware + // can't typically control the exact timing of its installation. + ListenableFuture preAuthResultFuture = + register(checkServerAuthorizationAsync(serviceInfo.applicationInfo.uid)); + Futures.addCallback( + preAuthResultFuture, + new FutureCallback() { + @Override + public void onSuccess(Status result) { + handlePreAuthResult(result); + } + + @Override + public void onFailure(Throwable t) { + handleAuthResult(t); + } + }, + offloadExecutor); + } + + private synchronized void handlePreAuthResult(Status authorization) { + if (!inState(TransportState.SETUP)) { + return; + } + + if (!authorization.isOk()) { + shutdownInternal(authorization, true); + return; + } + + serviceBinding.bind(); + } + + private synchronized void onReadyTimeout() { + if (inState(TransportState.SETUP)) { + readyTimeoutFuture = null; + shutdownInternal( + Status.DEADLINE_EXCEEDED.withDescription( + "Connect timeout " + readyTimeoutMillis + "ms lapsed"), + true); + } + } + + @Override + public synchronized ClientStream newStream( + final MethodDescriptor method, + final Metadata headers, + final CallOptions callOptions, + ClientStreamTracer[] tracers) { + if (!inState(TransportState.READY)) { + return newFailingClientStream( + isShutdown() + ? 
shutdownStatus + : Status.INTERNAL.withDescription("newStream() before transportReady()"), + attributes, + headers, + tracers); + } + + int callId = latestCallId++; + if (latestCallId == LAST_CALL_ID) { + latestCallId = FIRST_CALL_ID; + } + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, attributes, headers); + Inbound.ClientInbound inbound = + new Inbound.ClientInbound( + this, attributes, callId, GrpcUtil.shouldBeCountedForInUse(callOptions)); + if (ongoingCalls.putIfAbsent(callId, inbound) != null) { + Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); + shutdownInternal(failure, true); + return newFailingClientStream(failure, attributes, headers, tracers); + } + + if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { + clientTransportListener.transportInUse(true); + } + Outbound.ClientOutbound outbound = + new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); + if (method.getType().clientSendsOneMessage()) { + return new SingleMessageClientStream(inbound, outbound, attributes); + } else { + return new MultiMessageClientStream(inbound, outbound, attributes); + } + } + + @Override + protected void unregisterInbound(Inbound inbound) { + if (inbound.countsForInUse() && numInUseStreams.decrementAndGet() == 0) { + clientTransportListener.transportInUse(false); + } + super.unregisterInbound(inbound); + } + + @Override + public void ping(final PingCallback callback, Executor executor) { + pingTracker.startPing(callback, executor); + } + + @Override + public synchronized void shutdown(Status reason) { + checkNotNull(reason, "reason"); + shutdownInternal(reason, false); + } + + @Override + public synchronized void shutdownNow(Status reason) { + checkNotNull(reason, "reason"); + shutdownInternal(reason, true); + } + + @Override + @GuardedBy("this") + void notifyShutdown(Status status) { + clientTransportListener.transportShutdown(status, SimpleDisconnectError.UNKNOWN); + } 
+ + @Override + @GuardedBy("this") + void notifyTerminated() { + if (numInUseStreams.getAndSet(0) > 0) { + clientTransportListener.transportInUse(false); + } + if (readyTimeoutFuture != null) { + readyTimeoutFuture.cancel(false); + readyTimeoutFuture = null; + } + serviceBinding.unbind(); + clientTransportListener.transportTerminated(); + } + + @Override + @GuardedBy("this") + protected void handleSetupTransport(Parcel parcel) { + if (!inState(TransportState.SETUP)) { + return; + } + + int version = parcel.readInt(); + if (version != WIRE_FORMAT_VERSION) { + shutdownInternal(Status.UNAVAILABLE.withDescription("Wire format version mismatch"), true); + return; + } + + IBinder binder = parcel.readStrongBinder(); + if (binder == null) { + shutdownInternal(Status.UNAVAILABLE.withDescription("Malformed SETUP_TRANSPORT data"), true); + return; + } + + if (!setOutgoingBinder(OneWayBinderProxy.wrap(binder, offloadExecutor))) { + shutdownInternal( + Status.UNAVAILABLE.withDescription("Failed to observe outgoing binder"), true); + return; + } + handshake.handleSetupTransport(); + } + + @GuardedBy("this") + private void checkServerAuthorization(int remoteUid) { + ListenableFuture authResultFuture = register(checkServerAuthorizationAsync(remoteUid)); + Futures.addCallback( + authResultFuture, + new FutureCallback() { + @Override + public void onSuccess(Status result) { + handleAuthResult(result); + } + + @Override + public void onFailure(Throwable t) { + handleAuthResult(t); + } + }, + offloadExecutor); + } + + private ListenableFuture checkServerAuthorizationAsync(int remoteUid) { + return (securityPolicy instanceof AsyncSecurityPolicy) + ? 
((AsyncSecurityPolicy) securityPolicy).checkAuthorizationAsync(remoteUid) + : Futures.submit(() -> securityPolicy.checkAuthorization(remoteUid), offloadExecutor); + } + + private synchronized void handleAuthResult(Status authorization) { + if (!inState(TransportState.SETUP)) { + return; + } + + if (!authorization.isOk()) { + shutdownInternal(authorization, true); + return; + } + handshake.onServerAuthorizationOk(); + } + + private final class V2ClientHandshake implements ClientHandshake { + + private OneWayBinderProxy endpointBinder; + + @Override + @GuardedBy("BinderClientTransport.this") // By way of @GuardedBy("this") `handshake` member. + public void onBound(OneWayBinderProxy endpointBinder) { + this.endpointBinder = endpointBinder; + Futures.addCallback( + Futures.submit(serviceBinding::getConnectedServiceInfo, offloadExecutor), + new FutureCallback() { + @Override + public void onSuccess(ServiceInfo result) { + synchronized (BinderClientTransport.this) { + onConnectedServiceInfo(result); + } + } + + @Override + public void onFailure(Throwable t) { + synchronized (BinderClientTransport.this) { + shutdownInternal(Status.fromThrowable(t), true); + } + } + }, + offloadExecutor); + } + + @GuardedBy("BinderClientTransport.this") + private void onConnectedServiceInfo(ServiceInfo serviceInfo) { + if (!inState(TransportState.SETUP)) { + return; + } + attributes = setSecurityAttrs(attributes, serviceInfo.applicationInfo.uid); + checkServerAuthorization(serviceInfo.applicationInfo.uid); + } + + @Override + @GuardedBy("BinderClientTransport.this") + public void onServerAuthorizationOk() { + sendSetupTransaction(endpointBinder); + } + + @Override + @GuardedBy("BinderClientTransport.this") // By way of @GuardedBy("this") `handshake` member. 
+ public void handleSetupTransport() { + onHandshakeComplete(); + } + } + + @GuardedBy("this") + private void onHandshakeComplete() { + setState(TransportState.READY); + attributes = clientTransportListener.filterTransport(attributes); + clientTransportListener.transportReady(); + if (readyTimeoutFuture != null) { + readyTimeoutFuture.cancel(false); + readyTimeoutFuture = null; + } + } + + private synchronized void handleAuthResult(Throwable t) { + shutdownInternal( + Status.INTERNAL.withDescription("Could not evaluate SecurityPolicy").withCause(t), true); + } + + @GuardedBy("this") + @Override + protected void handlePingResponse(Parcel parcel) { + pingTracker.onPingResponse(parcel.readInt()); + } + + /** + * An abstract implementation of the client's connection handshake. + * + *

Supports a clean migration away from the legacy approach, one client at a time. + */ + private interface ClientHandshake { + /** + * Notifies the implementation that the binding has succeeded and we are now connected to the + * server's "endpoint" which can be reached at 'endpointBinder'. + */ + @MainThread + void onBound(OneWayBinderProxy endpointBinder); + + /** Notifies the implementation that we've received a valid SETUP_TRANSPORT transaction. */ + @BinderThread + void handleSetupTransport(); + + /** Notifies the implementation that the SecurityPolicy check of the server succeeded. */ + void onServerAuthorizationOk(); + } + + private final class LegacyClientHandshake implements ClientHandshake { + @Override + @MainThread + @GuardedBy("BinderClientTransport.this") // By way of @GuardedBy("this") `handshake` member. + public void onBound(OneWayBinderProxy binder) { + sendSetupTransaction(binder); + } + + @Override + @BinderThread + @GuardedBy("BinderClientTransport.this") // By way of @GuardedBy("this") `handshake` member. + public void handleSetupTransport() { + int remoteUid = Binder.getCallingUid(); + restrictIncomingBinderToCallsFrom(remoteUid); + attributes = setSecurityAttrs(attributes, remoteUid); + checkServerAuthorization(remoteUid); + } + + @Override + @GuardedBy("BinderClientTransport.this") // By way of @GuardedBy("this") `handshake` member. 
+ public void onServerAuthorizationOk() { + onHandshakeComplete(); + } + } + + private static ClientStream newFailingClientStream( + Status failure, Attributes attributes, Metadata headers, ClientStreamTracer[] tracers) { + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, attributes, headers); + statsTraceContext.clientOutboundHeaders(); + return new FailingClientStream(failure, tracers); + } + + private static InternalLogId buildLogId( + Context sourceContext, AndroidComponentAddress targetAddress) { + return InternalLogId.allocate( + BinderClientTransport.class, + sourceContext.getClass().getSimpleName() + "->" + targetAddress); + } + + private static Attributes buildClientAttributes( + Attributes eagAttrs, + Context sourceContext, + AndroidComponentAddress targetAddress, + InboundParcelablePolicy inboundParcelablePolicy) { + return Attributes.newBuilder() + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE) // Trust noone for now. + .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs) + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, AndroidComponentAddress.forContext(sourceContext)) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, targetAddress) + .set(INBOUND_PARCELABLE_POLICY, inboundParcelablePolicy) + .build(); + } + + private static Attributes setSecurityAttrs(Attributes attributes, int uid) { + return attributes.toBuilder() + .set(REMOTE_UID, uid) + .set( + GrpcAttributes.ATTR_SECURITY_LEVEL, + uid == Process.myUid() + ? SecurityLevel.PRIVACY_AND_INTEGRITY + : SecurityLevel.INTEGRITY) // TODO: Have the SecrityPolicy decide this. 
+ .build(); + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderClientTransportFactory.java b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransportFactory.java index 1e2b80b2fdb..459e064ad9b 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderClientTransportFactory.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransportFactory.java @@ -18,7 +18,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import android.content.Context; -import android.os.UserHandle; import androidx.core.content.ContextCompat; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; @@ -39,7 +38,6 @@ import java.util.Collections; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; /** Creates new binder transports. */ @Internal @@ -50,11 +48,12 @@ public final class BinderClientTransportFactory implements ClientTransportFactor final ObjectPool scheduledExecutorPool; final ObjectPool offloadExecutorPool; final SecurityPolicy securityPolicy; - @Nullable final UserHandle targetUserHandle; final BindServiceFlags bindServiceFlags; final InboundParcelablePolicy inboundParcelablePolicy; final OneWayBinderProxy.Decorator binderDecorator; final long readyTimeoutMillis; + final boolean preAuthorizeServers; // TODO(jdcormie): Default to true. 
+ final boolean useLegacyAuthStrategy; ScheduledExecutorService executorService; Executor offloadExecutor; @@ -70,23 +69,24 @@ private BinderClientTransportFactory(Builder builder) { scheduledExecutorPool = checkNotNull(builder.scheduledExecutorPool); offloadExecutorPool = checkNotNull(builder.offloadExecutorPool); securityPolicy = checkNotNull(builder.securityPolicy); - targetUserHandle = builder.targetUserHandle; bindServiceFlags = checkNotNull(builder.bindServiceFlags); inboundParcelablePolicy = checkNotNull(builder.inboundParcelablePolicy); binderDecorator = checkNotNull(builder.binderDecorator); readyTimeoutMillis = builder.readyTimeoutMillis; + preAuthorizeServers = builder.preAuthorizeServers; + useLegacyAuthStrategy = builder.useLegacyAuthStrategy; executorService = scheduledExecutorPool.getObject(); offloadExecutor = offloadExecutorPool.getObject(); } @Override - public BinderTransport.BinderClientTransport newClientTransport( + public BinderClientTransport newClientTransport( SocketAddress addr, ClientTransportOptions options, ChannelLogger channelLogger) { if (closed) { throw new IllegalStateException("The transport factory is closed."); } - return new BinderTransport.BinderClientTransport(this, (AndroidComponentAddress) addr, options); + return new BinderClientTransport(this, (AndroidComponentAddress) addr, options); } @Override @@ -123,11 +123,12 @@ public static final class Builder implements ClientTransportFactoryBuilder { ObjectPool scheduledExecutorPool = SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); SecurityPolicy securityPolicy = SecurityPolicies.internalOnly(); - @Nullable UserHandle targetUserHandle; BindServiceFlags bindServiceFlags = BindServiceFlags.DEFAULTS; InboundParcelablePolicy inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; OneWayBinderProxy.Decorator binderDecorator = OneWayBinderProxy.IDENTITY_DECORATOR; long readyTimeoutMillis = 60_000; + boolean preAuthorizeServers; + boolean useLegacyAuthStrategy = true; // 
TODO(jdcormie): Default to false. @Override public BinderClientTransportFactory buildClientTransportFactory() { @@ -139,6 +140,10 @@ public Builder setSourceContext(Context sourceContext) { return this; } + public Context getSourceContext() { + return sourceContext; + } + public Builder setOffloadExecutorPool(ObjectPool offloadExecutorPool) { this.offloadExecutorPool = checkNotNull(offloadExecutorPool, "offloadExecutorPool"); return this; @@ -165,11 +170,6 @@ public Builder setSecurityPolicy(SecurityPolicy securityPolicy) { return this; } - public Builder setTargetUserHandle(@Nullable UserHandle targetUserHandle) { - this.targetUserHandle = targetUserHandle; - return this; - } - public Builder setBindServiceFlags(BindServiceFlags bindServiceFlags) { this.bindServiceFlags = checkNotNull(bindServiceFlags, "bindServiceFlags"); return this; @@ -216,5 +216,17 @@ public Builder setReadyTimeoutMillis(long readyTimeoutMillis) { this.readyTimeoutMillis = readyTimeoutMillis; return this; } + + /** Whether to check server addresses against the SecurityPolicy *before* binding to them. */ + public Builder setPreAuthorizeServers(boolean preAuthorizeServers) { + this.preAuthorizeServers = preAuthorizeServers; + return this; + } + + /** Specifies which version of the client handshake to use. 
*/ + public Builder setUseLegacyAuthStrategy(boolean useLegacyAuthStrategy) { + this.useLegacyAuthStrategy = useLegacyAuthStrategy; + return this; + } } } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderServer.java b/binder/src/main/java/io/grpc/binder/internal/BinderServer.java index 0ad54fb74d1..96685a2f8bd 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderServer.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderServer.java @@ -25,6 +25,7 @@ import android.os.Parcel; import android.os.RemoteException; import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.Grpc; import io.grpc.InternalChannelz.SocketStats; @@ -48,7 +49,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -178,14 +178,14 @@ public synchronized boolean handleTransaction(int code, Parcel parcel) { serverPolicyChecker, checkNotNull(executor, "Not started?")); // Create a new transport and let our listener know about it. 
- BinderTransport.BinderServerTransport transport = - new BinderTransport.BinderServerTransport( + BinderServerTransport transport = + BinderServerTransport.create( executorServicePool, attrsBuilder.build(), streamTracerFactories, OneWayBinderProxy.IDENTITY_DECORATOR, callbackBinder); - transport.setServerTransportListener(listener.transportCreated(transport)); + transport.start(listener.transportCreated(transport)); return true; } } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java new file mode 100644 index 00000000000..b8ab5e9f843 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java @@ -0,0 +1,157 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.grpc.binder.internal; + +import android.os.IBinder; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.Internal; +import io.grpc.InternalLogId; +import io.grpc.Metadata; +import io.grpc.ServerStreamTracer; +import io.grpc.Status; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerTransport; +import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.StatsTraceContext; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import javax.annotation.Nullable; + +/** Concrete server-side transport implementation. */ +@Internal +public final class BinderServerTransport extends BinderTransport implements ServerTransport { + + private final List streamTracerFactories; + + @GuardedBy("this") + private final SimplePromise listenerPromise = new SimplePromise<>(); + + private BinderServerTransport( + ObjectPool executorServicePool, + Attributes attributes, + List streamTracerFactories, + OneWayBinderProxy.Decorator binderDecorator) { + super(executorServicePool, attributes, binderDecorator, buildLogId(attributes)); + this.streamTracerFactories = streamTracerFactories; + } + + /** + * Constructs a new transport instance. + * + * @param binderDecorator used to decorate 'callbackBinder', for fault injection. + */ + public static BinderServerTransport create( + ObjectPool executorServicePool, + Attributes attributes, + List streamTracerFactories, + OneWayBinderProxy.Decorator binderDecorator, + IBinder callbackBinder) { + BinderServerTransport transport = + new BinderServerTransport( + executorServicePool, attributes, streamTracerFactories, binderDecorator); + // TODO(jdcormie): Plumb in the Server's executor() and use it here instead. + // No need to handle failure here because if 'callbackBinder' is already dead, we'll notice it + // again in start() when we send the first transaction. 
+ synchronized (transport) { + transport.setOutgoingBinder( + OneWayBinderProxy.wrap(callbackBinder, transport.getScheduledExecutorService())); + } + return transport; + } + + /** + * Initializes this transport instance. + * + *

Must be called exactly once, even if {@link #shutdown} or {@link #shutdownNow} was called + * first. + * + * @param serverTransportListener where this transport will report events + */ + public synchronized void start(ServerTransportListener serverTransportListener) { + this.listenerPromise.set(serverTransportListener); + if (isShutdown()) { + // It's unlikely, but we could be shutdown externally between construction and start(). One + // possible cause is an extremely short handshake timeout. + return; + } + + sendSetupTransaction(); + + // Check we're not shutdown again, since a failure inside sendSetupTransaction (or a callback + // it triggers), could have shut us down. + if (isShutdown()) { + return; + } + + setState(TransportState.READY); + attributes = serverTransportListener.transportReady(attributes); + } + + StatsTraceContext createStatsTraceContext(String methodName, Metadata headers) { + return StatsTraceContext.newServerContext(streamTracerFactories, methodName, headers); + } + + /** + * Reports a new ServerStream requested by the remote client. + * + *

Precondition: {@link #start(ServerTransportListener)} must already have been called. + */ + synchronized Status startStream(ServerStream stream, String methodName, Metadata headers) { + if (isShutdown()) { + return Status.UNAVAILABLE.withDescription("transport is shutdown"); + } + + listenerPromise.get().streamCreated(stream, methodName, headers); + return Status.OK; + } + + @Override + @GuardedBy("this") + void notifyShutdown(Status status) { + // Nothing to do. + } + + @Override + @GuardedBy("this") + void notifyTerminated() { + listenerPromise.runWhenSet(ServerTransportListener::transportTerminated); + } + + @Override + public synchronized void shutdown() { + shutdownInternal(Status.OK, false); + } + + @Override + public synchronized void shutdownNow(Status reason) { + shutdownInternal(reason, true); + } + + @Override + @Nullable + @GuardedBy("this") + protected Inbound createInbound(int callId) { + return new Inbound.ServerInbound(this, attributes, callId); + } + + private static InternalLogId buildLogId(Attributes attributes) { + return InternalLogId.allocate( + BinderServerTransport.class, "from " + attributes.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java index 81eecf975bf..1592f6977df 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java @@ -19,51 +19,29 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.Futures.immediateFuture; -import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static io.grpc.binder.internal.TransactionUtils.newCallerFilteringHandler; -import android.content.Context; -import android.os.Binder; import android.os.DeadObjectException; import android.os.IBinder; 
import android.os.Parcel; -import android.os.Process; import android.os.RemoteException; import android.os.TransactionTooLargeException; +import androidx.annotation.BinderThread; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Ticker; import com.google.common.base.Verify; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; -import io.grpc.CallOptions; -import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.Internal; +import io.grpc.InternalChannelz; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.SecurityLevel; -import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.StatusException; -import io.grpc.binder.AndroidComponentAddress; -import io.grpc.binder.AsyncSecurityPolicy; import io.grpc.binder.InboundParcelablePolicy; -import io.grpc.binder.SecurityPolicy; -import io.grpc.internal.ClientStream; -import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; -import io.grpc.internal.ConnectionClientTransport; -import io.grpc.internal.FailingClientStream; -import io.grpc.internal.GrpcAttributes; -import io.grpc.internal.GrpcUtil; -import io.grpc.internal.ManagedClientTransport; +import io.grpc.binder.internal.LeakSafeOneWayBinder.TransactionHandler; import io.grpc.internal.ObjectPool; -import io.grpc.internal.ServerStream; -import io.grpc.internal.ServerTransport; -import io.grpc.internal.ServerTransportListener; -import io.grpc.internal.StatsTraceContext; import java.util.ArrayList; import java.util.Iterator; import java.util.LinkedHashSet; @@ -71,16 +49,11 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.concurrent.ConcurrentHashMap; -import 
java.util.concurrent.Executor; +import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -105,8 +78,7 @@ * https://github.com/grpc/proposal/blob/master/L73-java-binderchannel/wireformat.md */ @ThreadSafe -public abstract class BinderTransport - implements LeakSafeOneWayBinder.TransactionHandler, IBinder.DeathRecipient { +public abstract class BinderTransport implements IBinder.DeathRecipient { private static final Logger logger = Logger.getLogger(BinderTransport.class.getName()); @@ -168,10 +140,10 @@ public abstract class BinderTransport private static final int RESERVED_TRANSACTIONS = 1000; /** The first call ID we can use. */ - private static final int FIRST_CALL_ID = IBinder.FIRST_CALL_TRANSACTION + RESERVED_TRANSACTIONS; + static final int FIRST_CALL_ID = IBinder.FIRST_CALL_TRANSACTION + RESERVED_TRANSACTIONS; /** The last call ID we can use. */ - private static final int LAST_CALL_ID = IBinder.LAST_CALL_TRANSACTION; + static final int LAST_CALL_ID = IBinder.LAST_CALL_TRANSACTION; /** The states of this transport. 
*/ protected enum TransportState { @@ -187,6 +159,8 @@ protected enum TransportState { private final ObjectPool executorServicePool; private final ScheduledExecutorService scheduledExecutorService; private final InternalLogId logId; + + @GuardedBy("this") private final LeakSafeOneWayBinder incomingBinder; protected final ConcurrentHashMap> ongoingCalls; @@ -195,6 +169,9 @@ protected enum TransportState { @GuardedBy("this") private final LinkedHashSet callIdsToNotifyWhenReady = new LinkedHashSet<>(); + @GuardedBy("this") + private final List> ownedFutures = new ArrayList<>(); // To cancel upon terminate. + @GuardedBy("this") protected Attributes attributes; @@ -210,12 +187,14 @@ protected enum TransportState { private final FlowController flowController; /** The number of incoming bytes we've received. */ - private final AtomicLong numIncomingBytes; + // Only read/written on @BinderThread. + private long numIncomingBytes; /** The number of incoming bytes we've told our peer we've received. */ + // Only read/written on @BinderThread. private long acknowledgedIncomingBytes; - private BinderTransport( + protected BinderTransport( ObjectPool executorServicePool, Attributes attributes, OneWayBinderProxy.Decorator binderDecorator, @@ -225,10 +204,9 @@ private BinderTransport( this.attributes = attributes; this.logId = logId; scheduledExecutorService = executorServicePool.getObject(); - incomingBinder = new LeakSafeOneWayBinder(this); + incomingBinder = new LeakSafeOneWayBinder(this::handleTransaction); ongoingCalls = new ConcurrentHashMap<>(); flowController = new FlowController(TRANSACTION_BYTES_WINDOW); - numIncomingBytes = new AtomicLong(); } // Override in child class. @@ -238,7 +216,15 @@ public final ScheduledExecutorService getScheduledExecutorService() { // Override in child class. 
public final ListenableFuture getStats() { - return immediateFuture(null); + Attributes attributes = getAttributes(); + return immediateFuture( + new InternalChannelz.SocketStats( + /* data= */ null, // TODO: Keep track of these stats with TransportTracer or similar. + /* local= */ attributes.get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR), + /* remote= */ attributes.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR), + // TODO: SocketOptions are meaningless for binder but we're still forced to provide one. + new InternalChannelz.SocketOptions.Builder().build(), + /* security= */ null)); } // Override in child class. @@ -269,6 +255,13 @@ void releaseExecutors() { executorServicePool.returnObject(scheduledExecutorService); } + // Registers the specified future for eventual safe cancellation upon shutdown/terminate. + @GuardedBy("this") + protected final > T register(T future) { + ownedFutures.add(future); + return future; + } + @GuardedBy("this") boolean inState(TransportState transportState) { return this.transportState == transportState; @@ -285,6 +278,14 @@ final void setState(TransportState newState) { transportState = newState; } + /** + * Sets the binder to use for sending subsequent transactions to our peer. + * + *

Subclasses should call this as early as possible but not from a constructor. + * + *

Returns true for success, false if the process hosting 'binder' is already dead. Callers are + * responsible for handling this. + */ @GuardedBy("this") protected boolean setOutgoingBinder(OneWayBinderProxy binder) { binder = binderDecorator.decorate(binder); @@ -299,7 +300,10 @@ protected boolean setOutgoingBinder(OneWayBinderProxy binder) { @Override public synchronized void binderDied() { - shutdownInternal(Status.UNAVAILABLE.withDescription("binderDied"), true); + shutdownInternal( + Status.UNAVAILABLE.withDescription( + "Peer process crashed, exited or was killed (binderDied)"), + true); } @GuardedBy("this") @@ -316,6 +320,8 @@ final void shutdownInternal(Status shutdownStatus, boolean forceTerminate) { sendShutdownTransaction(); ArrayList> calls = new ArrayList<>(ongoingCalls.values()); ongoingCalls.clear(); + ArrayList> futuresToCancel = new ArrayList<>(ownedFutures); + ownedFutures.clear(); scheduledExecutorService.execute( () -> { for (Inbound inbound : calls) { @@ -323,6 +329,12 @@ final void shutdownInternal(Status shutdownStatus, boolean forceTerminate) { inbound.closeAbnormal(shutdownStatus); } } + + for (Future future : futuresToCancel) { + // Not holding any locks here just in case some listener runs on a direct Executor. + future.cancel(false); // No effect if already isDone(). 
+ } + synchronized (this) { notifyTerminated(); } @@ -423,8 +435,9 @@ final void sendOutOfBandClose(int callId, Status status) { } } - @Override - public final boolean handleTransaction(int code, Parcel parcel) { + @BinderThread + @VisibleForTesting + final boolean handleTransaction(int code, Parcel parcel) { try { return handleTransactionInternal(code, parcel); } catch (RuntimeException e) { @@ -440,6 +453,7 @@ public final boolean handleTransaction(int code, Parcel parcel) { } } + @BinderThread private boolean handleTransactionInternal(int code, Parcel parcel) { if (code < FIRST_CALL_ID) { synchronized (this) { @@ -483,16 +497,26 @@ private boolean handleTransactionInternal(int code, Parcel parcel) { if (inbound != null) { inbound.handleTransaction(parcel); } - long nib = numIncomingBytes.addAndGet(size); - if ((nib - acknowledgedIncomingBytes) > TRANSACTION_BYTES_WINDOW_FORCE_ACK) { + numIncomingBytes += size; + if ((numIncomingBytes - acknowledgedIncomingBytes) > TRANSACTION_BYTES_WINDOW_FORCE_ACK) { synchronized (this) { - sendAcknowledgeBytes(checkNotNull(outgoingBinder)); + sendAcknowledgeBytes(checkNotNull(outgoingBinder), numIncomingBytes); } + acknowledgedIncomingBytes = numIncomingBytes; } return true; } } + @BinderThread + @GuardedBy("this") + protected void restrictIncomingBinderToCallsFrom(int allowedCallingUid) { + TransactionHandler currentHandler = incomingBinder.getHandler(); + if (currentHandler != null) { + incomingBinder.setHandler(newCallerFilteringHandler(allowedCallingUid, currentHandler)); + } + } + @Nullable @GuardedBy("this") protected Inbound createInbound(int callId) { @@ -519,10 +543,8 @@ private final void handlePing(Parcel requestParcel) { protected void handlePingResponse(Parcel parcel) {} @GuardedBy("this") - private void sendAcknowledgeBytes(OneWayBinderProxy iBinder) { + private void sendAcknowledgeBytes(OneWayBinderProxy iBinder, long n) { // Send a transaction to acknowledge reception of incoming data. 
- long n = numIncomingBytes.get(); - acknowledgedIncomingBytes = n; try (ParcelHolder parcel = ParcelHolder.obtain()) { parcel.get().writeLong(n); iBinder.transact(ACKNOWLEDGE_BYTES, parcel); @@ -553,412 +575,6 @@ final void handleAcknowledgedBytes(long numBytes) { } } - /** Concrete client-side transport implementation. */ - @ThreadSafe - @Internal - public static final class BinderClientTransport extends BinderTransport - implements ConnectionClientTransport, Bindable.Observer { - - private final ObjectPool offloadExecutorPool; - private final Executor offloadExecutor; - private final SecurityPolicy securityPolicy; - private final Bindable serviceBinding; - - /** Number of ongoing calls which keep this transport "in-use". */ - private final AtomicInteger numInUseStreams; - - private final long readyTimeoutMillis; - private final PingTracker pingTracker; - - @Nullable private ManagedClientTransport.Listener clientTransportListener; - - @GuardedBy("this") - private int latestCallId = FIRST_CALL_ID; - - @GuardedBy("this") - private ScheduledFuture readyTimeoutFuture; // != null iff timeout scheduled. - - /** - * Constructs a new transport instance. 
- * - * @param factory parameters common to all a Channel's transports - * @param targetAddress the fully resolved and load-balanced server address - * @param options other parameters that can vary as transports come and go within a Channel - */ - public BinderClientTransport( - BinderClientTransportFactory factory, - AndroidComponentAddress targetAddress, - ClientTransportOptions options) { - super( - factory.scheduledExecutorPool, - buildClientAttributes( - options.getEagAttributes(), - factory.sourceContext, - targetAddress, - factory.inboundParcelablePolicy), - factory.binderDecorator, - buildLogId(factory.sourceContext, targetAddress)); - this.offloadExecutorPool = factory.offloadExecutorPool; - this.securityPolicy = factory.securityPolicy; - this.offloadExecutor = offloadExecutorPool.getObject(); - this.readyTimeoutMillis = factory.readyTimeoutMillis; - numInUseStreams = new AtomicInteger(); - pingTracker = new PingTracker(Ticker.systemTicker(), (id) -> sendPing(id)); - - serviceBinding = - new ServiceBinding( - factory.mainThreadExecutor, - factory.sourceContext, - factory.channelCredentials, - targetAddress.asBindIntent(), - factory.targetUserHandle, - factory.bindServiceFlags.toInteger(), - this); - } - - @Override - void releaseExecutors() { - super.releaseExecutors(); - offloadExecutorPool.returnObject(offloadExecutor); - } - - @Override - public synchronized void onBound(IBinder binder) { - sendSetupTransaction( - binderDecorator.decorate(OneWayBinderProxy.wrap(binder, offloadExecutor))); - } - - @Override - public synchronized void onUnbound(Status reason) { - shutdownInternal(reason, true); - } - - @CheckReturnValue - @Override - public synchronized Runnable start(ManagedClientTransport.Listener clientTransportListener) { - this.clientTransportListener = checkNotNull(clientTransportListener); - return () -> { - synchronized (BinderClientTransport.this) { - if (inState(TransportState.NOT_STARTED)) { - setState(TransportState.SETUP); - 
serviceBinding.bind(); - if (readyTimeoutMillis >= 0) { - readyTimeoutFuture = - getScheduledExecutorService() - .schedule( - BinderClientTransport.this::onReadyTimeout, - readyTimeoutMillis, - MILLISECONDS); - } - } - } - }; - } - - private synchronized void onReadyTimeout() { - if (inState(TransportState.SETUP)) { - readyTimeoutFuture = null; - shutdownInternal( - Status.DEADLINE_EXCEEDED.withDescription( - "Connect timeout " + readyTimeoutMillis + "ms lapsed"), - true); - } - } - - @Override - public synchronized ClientStream newStream( - final MethodDescriptor method, - final Metadata headers, - final CallOptions callOptions, - ClientStreamTracer[] tracers) { - if (!inState(TransportState.READY)) { - return newFailingClientStream( - isShutdown() - ? shutdownStatus - : Status.INTERNAL.withDescription("newStream() before transportReady()"), - attributes, - headers, - tracers); - } - - int callId = latestCallId++; - if (latestCallId == LAST_CALL_ID) { - latestCallId = FIRST_CALL_ID; - } - StatsTraceContext statsTraceContext = - StatsTraceContext.newClientContext(tracers, attributes, headers); - Inbound.ClientInbound inbound = - new Inbound.ClientInbound( - this, attributes, callId, GrpcUtil.shouldBeCountedForInUse(callOptions)); - if (ongoingCalls.putIfAbsent(callId, inbound) != null) { - Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); - shutdownInternal(failure, true); - return newFailingClientStream(failure, attributes, headers, tracers); - } else { - if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { - clientTransportListener.transportInUse(true); - } - Outbound.ClientOutbound outbound = - new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); - if (method.getType().clientSendsOneMessage()) { - return new SingleMessageClientStream(inbound, outbound, attributes); - } else { - return new MultiMessageClientStream(inbound, outbound, attributes); - } - } - } - - @Override - protected void 
unregisterInbound(Inbound inbound) { - if (inbound.countsForInUse() && numInUseStreams.decrementAndGet() == 0) { - clientTransportListener.transportInUse(false); - } - super.unregisterInbound(inbound); - } - - @Override - public void ping(final PingCallback callback, Executor executor) { - pingTracker.startPing(callback, executor); - } - - @Override - public synchronized void shutdown(Status reason) { - checkNotNull(reason, "reason"); - shutdownInternal(reason, false); - } - - @Override - public synchronized void shutdownNow(Status reason) { - checkNotNull(reason, "reason"); - shutdownInternal(reason, true); - } - - @Override - @GuardedBy("this") - void notifyShutdown(Status status) { - clientTransportListener.transportShutdown(status); - } - - @Override - @GuardedBy("this") - void notifyTerminated() { - if (numInUseStreams.getAndSet(0) > 0) { - clientTransportListener.transportInUse(false); - } - if (readyTimeoutFuture != null) { - readyTimeoutFuture.cancel(false); - readyTimeoutFuture = null; - } - serviceBinding.unbind(); - clientTransportListener.transportTerminated(); - } - - @Override - @GuardedBy("this") - protected void handleSetupTransport(Parcel parcel) { - int remoteUid = Binder.getCallingUid(); - attributes = setSecurityAttrs(attributes, remoteUid); - if (inState(TransportState.SETUP)) { - int version = parcel.readInt(); - IBinder binder = parcel.readStrongBinder(); - if (version != WIRE_FORMAT_VERSION) { - shutdownInternal( - Status.UNAVAILABLE.withDescription("Wire format version mismatch"), true); - } else if (binder == null) { - shutdownInternal( - Status.UNAVAILABLE.withDescription("Malformed SETUP_TRANSPORT data"), true); - } else { - ListenableFuture authFuture = - (securityPolicy instanceof AsyncSecurityPolicy) - ? 
((AsyncSecurityPolicy) securityPolicy).checkAuthorizationAsync(remoteUid) - : Futures.submit( - () -> securityPolicy.checkAuthorization(remoteUid), offloadExecutor); - Futures.addCallback( - authFuture, - new FutureCallback() { - @Override - public void onSuccess(Status result) { - handleAuthResult(binder, result); - } - - @Override - public void onFailure(Throwable t) { - handleAuthResult(t); - } - }, - offloadExecutor); - } - } - } - - private synchronized void handleAuthResult(IBinder binder, Status authorization) { - if (inState(TransportState.SETUP)) { - if (!authorization.isOk()) { - shutdownInternal(authorization, true); - } else if (!setOutgoingBinder(OneWayBinderProxy.wrap(binder, offloadExecutor))) { - shutdownInternal( - Status.UNAVAILABLE.withDescription("Failed to observe outgoing binder"), true); - } else { - // Check state again, since a failure inside setOutgoingBinder (or a callback it - // triggers), could have shut us down. - if (!isShutdown()) { - setState(TransportState.READY); - attributes = clientTransportListener.filterTransport(attributes); - clientTransportListener.transportReady(); - if (readyTimeoutFuture != null) { - readyTimeoutFuture.cancel(false); - readyTimeoutFuture = null; - } - } - } - } - } - - private synchronized void handleAuthResult(Throwable t) { - shutdownInternal( - Status.INTERNAL.withDescription("Could not evaluate SecurityPolicy").withCause(t), true); - } - - @GuardedBy("this") - @Override - protected void handlePingResponse(Parcel parcel) { - pingTracker.onPingResponse(parcel.readInt()); - } - - private static ClientStream newFailingClientStream( - Status failure, Attributes attributes, Metadata headers, ClientStreamTracer[] tracers) { - StatsTraceContext statsTraceContext = - StatsTraceContext.newClientContext(tracers, attributes, headers); - statsTraceContext.clientOutboundHeaders(); - return new FailingClientStream(failure, tracers); - } - - private static InternalLogId buildLogId( - Context sourceContext, 
AndroidComponentAddress targetAddress) { - return InternalLogId.allocate( - BinderClientTransport.class, - sourceContext.getClass().getSimpleName() + "->" + targetAddress); - } - - private static Attributes buildClientAttributes( - Attributes eagAttrs, - Context sourceContext, - AndroidComponentAddress targetAddress, - InboundParcelablePolicy inboundParcelablePolicy) { - return Attributes.newBuilder() - .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE) // Trust noone for now. - .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs) - .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, AndroidComponentAddress.forContext(sourceContext)) - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, targetAddress) - .set(INBOUND_PARCELABLE_POLICY, inboundParcelablePolicy) - .build(); - } - - private static Attributes setSecurityAttrs(Attributes attributes, int uid) { - return attributes.toBuilder() - .set(REMOTE_UID, uid) - .set( - GrpcAttributes.ATTR_SECURITY_LEVEL, - uid == Process.myUid() - ? SecurityLevel.PRIVACY_AND_INTEGRITY - : SecurityLevel.INTEGRITY) // TODO: Have the SecrityPolicy decide this. - .build(); - } - } - - /** Concrete server-side transport implementation. */ - @Internal - public static final class BinderServerTransport extends BinderTransport - implements ServerTransport { - - private final List streamTracerFactories; - @Nullable private ServerTransportListener serverTransportListener; - - /** - * Constructs a new transport instance. - * - * @param binderDecorator used to decorate 'callbackBinder', for fault injection. - */ - public BinderServerTransport( - ObjectPool executorServicePool, - Attributes attributes, - List streamTracerFactories, - OneWayBinderProxy.Decorator binderDecorator, - IBinder callbackBinder) { - super(executorServicePool, attributes, binderDecorator, buildLogId(attributes)); - this.streamTracerFactories = streamTracerFactories; - // TODO(jdcormie): Plumb in the Server's executor() and use it here instead. 
- setOutgoingBinder(OneWayBinderProxy.wrap(callbackBinder, getScheduledExecutorService())); - } - - public synchronized void setServerTransportListener( - ServerTransportListener serverTransportListener) { - this.serverTransportListener = serverTransportListener; - if (isShutdown()) { - setState(TransportState.SHUTDOWN_TERMINATED); - notifyTerminated(); - releaseExecutors(); - } else { - sendSetupTransaction(); - // Check we're not shutdown again, since a failure inside sendSetupTransaction (or a - // callback it triggers), could have shut us down. - if (!isShutdown()) { - setState(TransportState.READY); - attributes = serverTransportListener.transportReady(attributes); - } - } - } - - StatsTraceContext createStatsTraceContext(String methodName, Metadata headers) { - return StatsTraceContext.newServerContext(streamTracerFactories, methodName, headers); - } - - synchronized Status startStream(ServerStream stream, String methodName, Metadata headers) { - if (isShutdown()) { - return Status.UNAVAILABLE.withDescription("transport is shutdown"); - } else { - serverTransportListener.streamCreated(stream, methodName, headers); - return Status.OK; - } - } - - @Override - @GuardedBy("this") - void notifyShutdown(Status status) { - // Nothing to do. 
- } - - @Override - @GuardedBy("this") - void notifyTerminated() { - if (serverTransportListener != null) { - serverTransportListener.transportTerminated(); - } - } - - @Override - public synchronized void shutdown() { - shutdownInternal(Status.OK, false); - } - - @Override - public synchronized void shutdownNow(Status reason) { - shutdownInternal(reason, true); - } - - @Override - @Nullable - @GuardedBy("this") - protected Inbound createInbound(int callId) { - return new Inbound.ServerInbound(this, attributes, callId); - } - - private static InternalLogId buildLogId(Attributes attributes) { - return InternalLogId.allocate( - BinderServerTransport.class, "from " + attributes.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); - } - } - private static void checkTransition(TransportState current, TransportState next) { switch (next) { case SETUP: @@ -986,6 +602,11 @@ Map> getOngoingCalls() { return ongoingCalls; } + @VisibleForTesting + synchronized LeakSafeOneWayBinder getIncomingBinderForTesting() { + return this.incomingBinder; + } + private static Status statusFromRemoteException(RemoteException e) { if (e instanceof DeadObjectException || e instanceof TransactionTooLargeException) { // These are to be expected from time to time and can simply be retried. 
diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java index 430eee3e041..6f95ef8a83c 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java @@ -20,6 +20,7 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.Internal; import io.grpc.Metadata; @@ -35,7 +36,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; /** diff --git a/binder/src/main/java/io/grpc/binder/internal/FlowController.java b/binder/src/main/java/io/grpc/binder/internal/FlowController.java index 1972ea00e6c..135f363a01e 100644 --- a/binder/src/main/java/io/grpc/binder/internal/FlowController.java +++ b/binder/src/main/java/io/grpc/binder/internal/FlowController.java @@ -15,7 +15,7 @@ */ package io.grpc.binder.internal; -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; /** Keeps track of the number of bytes on the wire in a single direction. 
*/ final class FlowController { diff --git a/binder/src/main/java/io/grpc/binder/internal/Inbound.java b/binder/src/main/java/io/grpc/binder/internal/Inbound.java index 19c0e4a0f08..9b9dfeef5ce 100644 --- a/binder/src/main/java/io/grpc/binder/internal/Inbound.java +++ b/binder/src/main/java/io/grpc/binder/internal/Inbound.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState; import android.os.Parcel; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.Metadata; import io.grpc.Status; @@ -34,7 +35,6 @@ import java.io.InputStream; import java.util.ArrayList; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Handles incoming binder transactions for a single stream, turning those transactions into calls @@ -610,10 +610,9 @@ protected void deliverCloseAbnormal(Status status) { // Server-side inbound transactions. static final class ServerInbound extends Inbound { - private final BinderTransport.BinderServerTransport serverTransport; + private final BinderServerTransport serverTransport; - ServerInbound( - BinderTransport.BinderServerTransport transport, Attributes attributes, int callId) { + ServerInbound(BinderServerTransport transport, Attributes attributes, int callId) { super(transport, attributes, callId); this.serverTransport = transport; } diff --git a/binder/src/main/java/io/grpc/binder/internal/IntentNameResolver.java b/binder/src/main/java/io/grpc/binder/internal/IntentNameResolver.java new file mode 100644 index 00000000000..ce3e2a96a42 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/IntentNameResolver.java @@ -0,0 +1,299 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static io.grpc.binder.internal.SystemApis.createContextAsUser; + +import android.annotation.SuppressLint; +import android.content.BroadcastReceiver; +import android.content.ComponentName; +import android.content.Context; +import android.content.Intent; +import android.content.IntentFilter; +import android.content.pm.PackageManager; +import android.content.pm.ResolveInfo; +import android.os.Build; +import android.os.UserHandle; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.StatusOr; +import io.grpc.SynchronizationContext; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.ApiConstants; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Executor; +import javax.annotation.Nullable; + +/** + * A {@link NameResolver} that resolves Android-standard "intent:" target URIs to the list of {@link + * AndroidComponentAddress} that match it by manifest intent filter. 
+ */ +final class IntentNameResolver extends NameResolver { + private final Intent targetIntent; // Never mutated. + @Nullable private final UserHandle targetUser; // null means same user that hosts this process. + private final Context targetUserContext; + private final Executor offloadExecutor; + private final Executor sequentialExecutor; + private final SynchronizationContext syncContext; + private final ServiceConfigParser serviceConfigParser; + + // Accessed only on `sequentialExecutor` + @Nullable private PackageChangeReceiver receiver; // != null when registered + + // Accessed only on 'syncContext'. + private boolean shutdown; + private boolean queryNeeded; + @Nullable private Listener2 listener; // != null after start(). + @Nullable private ListenableFuture queryResultFuture; // != null when querying. + + @EquivalentAddressGroup.Attr + private static final Attributes CONSTANT_EAG_ATTRS = + Attributes.newBuilder() + // Servers discovered in PackageManager are especially untrusted. After all, any app can + // declare any intent filter it wants! Require pre-authorization so that unauthorized apps + // don't even get a chance to run onCreate()/onBind(). + .set(ApiConstants.PRE_AUTH_SERVER_OVERRIDE, true) + .build(); + + IntentNameResolver(Intent targetIntent, Args args) { + this.targetIntent = targetIntent; + this.targetUser = args.getArg(ApiConstants.TARGET_ANDROID_USER); + Context context = + checkNotNull(args.getArg(ApiConstants.SOURCE_ANDROID_CONTEXT), "SOURCE_ANDROID_CONTEXT") + .getApplicationContext(); + this.targetUserContext = + targetUser != null ? createContextForTargetUserOrThrow(context, targetUser) : context; + // This Executor is nominally optional but all grpc-java Channels provide it since 1.25. + this.offloadExecutor = + checkNotNull(args.getOffloadExecutor(), "NameResolver.Args.getOffloadExecutor()"); + // Ensures start()'s work runs before resolve()'s' work, and both run before shutdown()'s. 
+ this.sequentialExecutor = MoreExecutors.newSequentialExecutor(offloadExecutor); + this.syncContext = args.getSynchronizationContext(); + this.serviceConfigParser = args.getServiceConfigParser(); + } + + private static Context createContextForTargetUserOrThrow(Context context, UserHandle targetUser) { + try { + return createContextAsUser(context, targetUser, /* flags= */ 0); // @SystemApi since R. + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException( + "TARGET_ANDROID_USER NameResolver.Arg requires SDK_INT >= R and @SystemApi visibility"); + } + } + + @Override + public void start(Listener2 listener) { + checkState(this.listener == null, "Already started!"); + checkState(!shutdown, "Resolver is shutdown"); + this.listener = checkNotNull(listener); + sequentialExecutor.execute(this::registerReceiver); + resolve(); + } + + @Override + public void refresh() { + checkState(listener != null, "Not started!"); + resolve(); + } + + private void resolve() { + syncContext.throwIfNotInThisSynchronizationContext(); + + if (shutdown) { + return; + } + + // We can't block here in 'syncContext' so we offload PackageManager queries to an Executor. + // But offloading complicates things a bit because other calls can arrive while we wait for the + // results. We keep 'listener' up-to-date with the latest state in PackageManager by doing: + // 1. Only one query-and-report-to-listener operation at a time. + // 2. At least one query-and-report-to-listener AFTER every PackageManager state change. + if (queryResultFuture == null) { + queryResultFuture = Futures.submit(this::queryPackageManager, sequentialExecutor); + queryResultFuture.addListener(this::onQueryComplete, syncContext); + } else { + // There's already a query in-flight but (2) says we need at least one more. Our sequential + // Executor would be enough to ensure (1) but we also don't want a backlog of work to build up + // if things change rapidly. 
Just make a note to start a new query when this one finishes. + queryNeeded = true; + } + } + + private void onQueryComplete() { + syncContext.throwIfNotInThisSynchronizationContext(); + checkState(queryResultFuture != null); + checkState(queryResultFuture.isDone()); + + // Capture non-final `listener` here while we're on 'syncContext'. + Listener2 listener = checkNotNull(this.listener); + Futures.addCallback( + queryResultFuture, // Already isDone() so this execute()s immediately. + new FutureCallback() { + @Override + public void onSuccess(ResolutionResult result) { + listener.onResult2(result); + } + + @Override + public void onFailure(Throwable t) { + listener.onResult2( + ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromStatus(Status.fromThrowable(t))) + .build()); + } + }, + syncContext); // Already on 'syncContext' but addCallback() is faster than try/get/catch. + queryResultFuture = null; + + if (queryNeeded) { + // One or more resolve() requests arrived while we were working on the last one. Just one + // follow-on query can subsume all of them. + queryNeeded = false; + resolve(); + } + } + + @Override + public String getServiceAuthority() { + return "localhost"; + } + + @Override + public void shutdown() { + syncContext.throwIfNotInThisSynchronizationContext(); + if (!shutdown) { + shutdown = true; + sequentialExecutor.execute(this::maybeUnregisterReceiver); + } + } + + private ResolutionResult queryPackageManager() throws StatusException { + List queryResults = queryIntentServices(targetIntent); + + // Avoid a spurious UnsafeIntentLaunchViolation later. Since S, Android's StrictMode is very + // conservative, marking any Intent parsed from a string as suspicious and complaining when you + // bind to it. But all this is pointless with grpc-binder, which already goes even further by + // not trusting addresses at all! 
Instead, we rely on SecurityPolicy, which won't allow a + // connection to an unauthorized server UID no matter how you got there. + Intent prototypeBindIntent = sanitize(targetIntent); + + // Model each matching android.app.Service as an EAG (server) with a single address. + List<EquivalentAddressGroup> addresses = new ArrayList<>(); + for (ResolveInfo resolveInfo : queryResults) { + prototypeBindIntent.setComponent( + new ComponentName(resolveInfo.serviceInfo.packageName, resolveInfo.serviceInfo.name)); + addresses.add( + new EquivalentAddressGroup( + AndroidComponentAddress.newBuilder() + .setBindIntent(prototypeBindIntent) // Makes a copy. + .setTargetUser(targetUser) + .build(), + CONSTANT_EAG_ATTRS)); + } + + return ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(addresses)) + // Empty service config means we get the default 'pick_first' load balancing policy. + .setServiceConfig(serviceConfigParser.parseServiceConfig(ImmutableMap.of())) + .build(); + } + + private List<ResolveInfo> queryIntentServices(Intent intent) throws StatusException { + int flags = 0; + if (Build.VERSION.SDK_INT >= 29) { + // Don't match direct-boot-unaware Services that can't presently be created. We'll query again + // after the user is unlocked. The MATCH_DIRECT_BOOT_AUTO behavior is actually the default but + // being explicit here avoids an android.os.strictmode.ImplicitDirectBootViolation. + flags |= PackageManager.MATCH_DIRECT_BOOT_AUTO; + } + + List<ResolveInfo> intentServices = + targetUserContext.getPackageManager().queryIntentServices(intent, flags); + if (intentServices == null || intentServices.isEmpty()) { + // Must be the same as when ServiceBinding's call to bindService() returns false. + throw Status.UNIMPLEMENTED + .withDescription("Service not found for intent " + intent) + .asException(); + } + return intentServices; + } + + // Returns a new Intent with the same action, data and categories as 'input'.
+ private static Intent sanitize(Intent input) { + Intent output = new Intent(); + output.setAction(input.getAction()); + output.setData(input.getData()); + + Set<String> categories = input.getCategories(); + if (categories != null) { + for (String category : categories) { + output.addCategory(category); + } + } + // Don't bother copying extras and flags since AndroidComponentAddress (rightly) ignores them. + // Don't bother copying package or ComponentName either, since we're about to set that. + return output; + } + + final class PackageChangeReceiver extends BroadcastReceiver { + @Override + public void onReceive(Context context, Intent intent) { + // Get off the main thread and into the correct SynchronizationContext. + syncContext.executeLater(IntentNameResolver.this::resolve); + offloadExecutor.execute(syncContext::drain); + } + } + + @SuppressLint("UnprotectedReceiver") // All of these are protected system broadcasts. + private void registerReceiver() { + checkState(receiver == null, "Already registered!"); + receiver = new PackageChangeReceiver(); + IntentFilter filter = new IntentFilter(); + filter.addDataScheme("package"); + filter.addAction(Intent.ACTION_PACKAGE_ADDED); + filter.addAction(Intent.ACTION_PACKAGE_CHANGED); + filter.addAction(Intent.ACTION_PACKAGE_REMOVED); + filter.addAction(Intent.ACTION_PACKAGE_REPLACED); + + targetUserContext.registerReceiver(receiver, filter); + + if (Build.VERSION.SDK_INT >= 24) { + // Clients running in direct boot mode must refresh() when the user is unlocked because + // that's when `directBootAware=false` services become visible in queryIntentServices() + // results. ACTION_BOOT_COMPLETED would work too but it's delivered with lower priority. + targetUserContext.registerReceiver(receiver, new IntentFilter(Intent.ACTION_USER_UNLOCKED)); + } + } + + private void maybeUnregisterReceiver() { + if (receiver != null) { // NameResolver API contract appears to allow shutdown without start().
+ targetUserContext.unregisterReceiver(receiver); + receiver = null; + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/IntentNameResolverProvider.java b/binder/src/main/java/io/grpc/binder/internal/IntentNameResolverProvider.java new file mode 100644 index 00000000000..5a3c9fcc986 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/IntentNameResolverProvider.java @@ -0,0 +1,77 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import static android.content.Intent.URI_INTENT_SCHEME; + +import android.content.Intent; +import com.google.common.collect.ImmutableSet; +import io.grpc.NameResolver; +import io.grpc.NameResolver.Args; +import io.grpc.NameResolverProvider; +import io.grpc.binder.AndroidComponentAddress; +import java.net.SocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Objects; +import javax.annotation.Nullable; + +/** + * A {@link NameResolverProvider} that handles Android-standard "intent:" target URIs, resolving + * them to the list of {@link AndroidComponentAddress} that match by manifest intent filter.
+ */ +public final class IntentNameResolverProvider extends NameResolverProvider { + + static final String ANDROID_INTENT_SCHEME = "intent"; + + @Override + public String getDefaultScheme() { + return ANDROID_INTENT_SCHEME; + } + + @Nullable + @Override + public NameResolver newNameResolver(URI targetUri, final Args args) { + if (Objects.equals(targetUri.getScheme(), ANDROID_INTENT_SCHEME)) { + return new IntentNameResolver(parseUriArg(targetUri.toString()), args); + } else { + return null; + } + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int priority() { + return 3; // Lower than DNS so we don't accidentally become the default scheme for a registry. + } + + @Override + public ImmutableSet<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + return ImmutableSet.of(AndroidComponentAddress.class); + } + + private static Intent parseUriArg(String targetUri) { + try { + return Intent.parseUri(targetUri, URI_INTENT_SCHEME); + } catch (URISyntaxException e) { + throw new IllegalArgumentException(e); + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/LeakSafeOneWayBinder.java b/binder/src/main/java/io/grpc/binder/internal/LeakSafeOneWayBinder.java index a12a5cb13cc..c36bc7d5bd3 100644 --- a/binder/src/main/java/io/grpc/binder/internal/LeakSafeOneWayBinder.java +++ b/binder/src/main/java/io/grpc/binder/internal/LeakSafeOneWayBinder.java @@ -19,6 +19,7 @@ import android.os.Binder; import android.os.IBinder; import android.os.Parcel; +import androidx.annotation.BinderThread; import io.grpc.Internal; import java.util.logging.Level; import java.util.logging.Logger; @@ -58,6 +59,7 @@ public interface TransactionHandler { * @return the value to return from
{@link Binder#onTransact}. NB: "oneway" semantics mean this * result will not delivered to the caller of {@link IBinder#transact} */ + @BinderThread boolean handleTransaction(int code, Parcel data); } @@ -71,7 +73,21 @@ public void detach() { setHandler(null); } - /** Replaces the current {@link TransactionHandler} with `handler`. */ + /** Returns the current {@link TransactionHandler} or null if already detached. */ + public @Nullable TransactionHandler getHandler() { + return handler; + } + + /** + * Replaces the current {@link TransactionHandler} with `handler`. + * + *

{@link TransactionHandler} mutations race against incoming transactions except in the + * special case where the caller is already handling an incoming transaction on this same {@link + * LeakSafeOneWayBinder} instance. In that case, mutations are safe and the provided 'handler' is + * guaranteed to be used for the very next transaction. This follows from the one-at-a-time + * property of one-way Binder transactions as explained by {@link + * TransactionHandler#handleTransaction}. + */ public void setHandler(@Nullable TransactionHandler handler) { this.handler = handler; } diff --git a/binder/src/main/java/io/grpc/binder/internal/Outbound.java b/binder/src/main/java/io/grpc/binder/internal/Outbound.java index e2896be02a1..7db5bf0fbe4 100644 --- a/binder/src/main/java/io/grpc/binder/internal/Outbound.java +++ b/binder/src/main/java/io/grpc/binder/internal/Outbound.java @@ -19,9 +19,9 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; -import static java.lang.Math.max; import android.os.Parcel; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Deadline; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -34,7 +34,6 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Sends the set of outbound transactions for a single BinderStream (rpc). 
@@ -397,8 +396,7 @@ protected int writeSuffix(Parcel parcel) throws IOException { @GuardedBy("this") void setDeadline(Deadline deadline) { headers.discardAll(TIMEOUT_KEY); - long effectiveTimeoutNanos = max(0, deadline.timeRemaining(TimeUnit.NANOSECONDS)); - headers.put(TIMEOUT_KEY, effectiveTimeoutNanos); + headers.put(TIMEOUT_KEY, deadline.timeRemaining(TimeUnit.NANOSECONDS)); } } diff --git a/binder/src/main/java/io/grpc/binder/internal/PingTracker.java b/binder/src/main/java/io/grpc/binder/internal/PingTracker.java index 33fcb43918f..5a4300443ba 100644 --- a/binder/src/main/java/io/grpc/binder/internal/PingTracker.java +++ b/binder/src/main/java/io/grpc/binder/internal/PingTracker.java @@ -17,12 +17,12 @@ package io.grpc.binder.internal; import com.google.common.base.Ticker; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Status; import io.grpc.StatusException; import io.grpc.internal.ClientTransport.PingCallback; import java.util.concurrent.Executor; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Tracks an ongoing ping request for a client-side binder transport. 
We only handle a single active @@ -99,7 +99,7 @@ private final class Ping { private synchronized void fail(Status status) { if (!done) { done = true; - executor.execute(() -> callback.onFailure(status.asException())); + executor.execute(() -> callback.onFailure(status)); } } diff --git a/binder/src/main/java/io/grpc/binder/internal/ServiceBinding.java b/binder/src/main/java/io/grpc/binder/internal/ServiceBinding.java index 76f1d7aa9f7..4b6bf7d06fb 100644 --- a/binder/src/main/java/io/grpc/binder/internal/ServiceBinding.java +++ b/binder/src/main/java/io/grpc/binder/internal/ServiceBinding.java @@ -17,24 +17,30 @@ package io.grpc.binder.internal; import static com.google.common.base.Preconditions.checkState; +import static io.grpc.binder.internal.SystemApis.createContextAsUser; import android.app.admin.DevicePolicyManager; import android.content.ComponentName; import android.content.Context; import android.content.Intent; import android.content.ServiceConnection; +import android.content.pm.PackageManager; +import android.content.pm.ResolveInfo; +import android.content.pm.ServiceInfo; +import android.os.Build; import android.os.IBinder; import android.os.UserHandle; import androidx.annotation.AnyThread; import androidx.annotation.MainThread; import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Status; +import io.grpc.StatusException; import io.grpc.binder.BinderChannelCredentials; import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -96,6 +102,9 @@ public String methodName() { private State reportedState; // Only used on the main thread. 
+ @GuardedBy("this") + private ComponentName connectedServiceName; + @AnyThread ServiceBinding( Executor mainThreadExecutor, @@ -183,18 +192,29 @@ private static Status bindInternal( bindResult = context.bindService(bindIntent, conn, flags); break; case BIND_SERVICE_AS_USER: - bindResult = context.bindServiceAsUser(bindIntent, conn, flags, targetUserHandle); + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) { + // We don't need SystemApis because bindServiceAsUser() is simply public in R+. + bindResult = context.bindServiceAsUser(bindIntent, conn, flags, targetUserHandle); + } else { + // TODO(#12279): Use SystemApis to make this work pre-R. + return Status.INTERNAL.withDescription("Cross user Channel requires Android R+"); + } break; case DEVICE_POLICY_BIND_SEVICE_ADMIN: DevicePolicyManager devicePolicyManager = (DevicePolicyManager) context.getSystemService(Context.DEVICE_POLICY_SERVICE); - bindResult = - devicePolicyManager.bindDeviceAdminServiceAsUser( - channelCredentials.getDevicePolicyAdminComponentName(), - bindIntent, - conn, - flags, - targetUserHandle); + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) { + bindResult = + devicePolicyManager.bindDeviceAdminServiceAsUser( + channelCredentials.getDevicePolicyAdminComponentName(), + bindIntent, + conn, + flags, + targetUserHandle); + } else { + return Status.INTERNAL.withDescription( + "Device policy admin binding requires Android R+"); + } break; } if (!bindResult) { @@ -247,11 +267,67 @@ void unbindInternal(Status reason) { } } + @AnyThread + @Override + public ServiceInfo resolve() throws StatusException { + int flags = 0; + if (Build.VERSION.SDK_INT >= 29) { + // Filter out non-'directBootAware' s when 'targetUserHandle' is locked. Here's why: + // Callers want 'bindIntent' to #resolve() to the same thing a follow-up call to #bind() will. + // But bindService() *always* ignores services that can't presently be created for lack of + // 'directBootAware'-ness. 
This flag explicitly tells resolveService() to act the same way. + flags |= PackageManager.MATCH_DIRECT_BOOT_AUTO; + } + ResolveInfo resolveInfo = + getContextForTargetUser("Cross-user pre-auth") + .getPackageManager() + .resolveService(bindIntent, flags); + if (resolveInfo == null) { + throw Status.UNIMPLEMENTED // Same status code as when bindService() returns false. + .withDescription("resolveService(" + bindIntent + " / " + targetUserHandle + ") was null") + .asException(); + } + return resolveInfo.serviceInfo; + } + + private Context getContextForTargetUser(String purpose) throws StatusException { + checkState(sourceContext != null, "Already unbound!"); + try { + return targetUserHandle == null + ? sourceContext + : createContextAsUser(sourceContext, targetUserHandle, /* flags= */ 0); + } catch (ReflectiveOperationException e) { + throw Status.INTERNAL + .withDescription(purpose + " requires SDK_INT >= R and @SystemApi visibility") + .asException(); + } + } + @MainThread private void clearReferences() { sourceContext = null; } + @AnyThread + @Override + public ServiceInfo getConnectedServiceInfo() throws StatusException { + try { + return getContextForTargetUser("cross-user v2 handshake") + .getPackageManager() + .getServiceInfo(getConnectedServiceName(), /* flags= */ 0); + } catch (PackageManager.NameNotFoundException e) { + throw Status.UNIMPLEMENTED + .withCause(e) + .withDescription("connected remote service was uninstalled/disabled during handshake") + .asException(); + } + } + + private synchronized ComponentName getConnectedServiceName() { + checkState(connectedServiceName != null, "onBound() not yet called!"); + return connectedServiceName; + } + @Override @MainThread public void onServiceConnected(ComponentName className, IBinder binder) { @@ -259,6 +335,7 @@ public void onServiceConnected(ComponentName className, IBinder binder) { synchronized (this) { if (state == State.BINDING) { state = State.BOUND; + connectedServiceName = className; bound = 
true; } } @@ -272,19 +349,32 @@ public void onServiceConnected(ComponentName className, IBinder binder) { @Override @MainThread public void onServiceDisconnected(ComponentName name) { - unbindInternal(Status.UNAVAILABLE.withDescription("onServiceDisconnected: " + name)); + unbindInternal( + Status.UNAVAILABLE.withDescription( + "Server process crashed, exited or was killed (onServiceDisconnected): " + name)); } @Override @MainThread public void onNullBinding(ComponentName name) { - unbindInternal(Status.UNIMPLEMENTED.withDescription("onNullBinding: " + name)); + unbindInternal( + Status.UNIMPLEMENTED.withDescription( + "Remote Service returned null from onBind() for " + + bindIntent + + " (onNullBinding): " + + name)); } @Override @MainThread public void onBindingDied(ComponentName name) { - unbindInternal(Status.UNAVAILABLE.withDescription("onBindingDied: " + name)); + unbindInternal( + Status.UNAVAILABLE.withDescription( + "Remote Service component " + + name.getClassName() + + " was disabled, or its package " + + name.getPackageName() + + " was disabled, force-stopped, replaced or uninstalled (onBindingDied).")); } @VisibleForTesting diff --git a/binder/src/main/java/io/grpc/binder/internal/SimplePromise.java b/binder/src/main/java/io/grpc/binder/internal/SimplePromise.java new file mode 100644 index 00000000000..c7d227fbf64 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/SimplePromise.java @@ -0,0 +1,97 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import java.util.ArrayList; +import java.util.List; + +/** + * Placeholder for an object that will be provided later. + * + *

Similar to {@link com.google.common.util.concurrent.SettableFuture}, except it cannot fail or + * be cancelled. Most importantly, this class guarantees that {@link Listener}s run one-at-a-time + * and in the same order that they were scheduled. This conveniently matches the expectations of + * most listener interfaces in the io.grpc universe. + * + *

Not safe for concurrent use by multiple threads. Thread-compatible for callers that provide + * synchronization externally. + */ +public class SimplePromise<T> { + private T value; + private List<Listener<T>> pendingListeners; // Allocated lazily in the hopes it's never needed. + + /** + * Provides the promised object and runs any pending listeners. + * + * @throws IllegalStateException if this method has already been called + * @throws RuntimeException if some pending listener threw when we tried to run it + */ + public void set(T value) { + checkNotNull(value, "value"); + checkState(this.value == null, "Already set!"); + this.value = value; + if (pendingListeners != null) { + for (Listener<T> listener : pendingListeners) { + listener.notify(value); + } + pendingListeners = null; + } + } + + /** + * Returns the promised object, under the assumption that it's already been set. + * + *

Compared to {@link #runWhenSet(Listener)}, this method may be a more efficient way to access + * the promised value in the case where you somehow know externally that {@link #set(T)} has + * "happened-before" this call. + * + * @throws IllegalStateException if {@link #set(T)} has not yet been called + */ + public T get() { + checkState(value != null, "Not yet set!"); + return value; + } + + /** + * Runs the given listener when this promise is fulfilled, or immediately if already fulfilled. + * + * @throws RuntimeException if already fulfilled and 'listener' threw when we tried to run it + */ + public void runWhenSet(Listener<T> listener) { + if (value != null) { + listener.notify(value); + } else { + if (pendingListeners == null) { + pendingListeners = new ArrayList<>(); + } + pendingListeners.add(listener); + } + } + + /** + * An object that wants to get notified when a SimplePromise has been fulfilled. + */ + public interface Listener<T> { + /** + * Indicates that the associated SimplePromise has been fulfilled with the given `value`. + */ + void notify(T value); + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/SystemApis.java b/binder/src/main/java/io/grpc/binder/internal/SystemApis.java new file mode 100644 index 00000000000..a4feec86a11 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/SystemApis.java @@ -0,0 +1,60 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.grpc.binder.internal; + +import android.content.Context; +import android.os.UserHandle; +import java.lang.reflect.Method; + +/** + * A collection of static methods that wrap hidden Android "System APIs." + * + *

grpc-java can't call Android methods marked @SystemApi directly, even though many of our users + * are "system apps" entitled to do so. Being a library built outside the Android source tree, these + * "non-SDK" elements simply don't exist from our compiler's perspective. Instead we resort to + * reflection but use the static wrappers found here to keep call sites readable and type safe. + * + *

Modern Android's JRE also limits the visibility of these methods at *runtime*. Only certain + * privileged apps installed on the system image app can call them, even using reflection, and this + * wrapper doesn't change that. Callers are responsible for ensuring that the host app actually has + * the ability to call @SystemApis and all methods throw {@link ReflectiveOperationException} as a + * reminder to do that. See + * https://developer.android.com/guide/app-compatibility/restrictions-non-sdk-interfaces for more. + */ +final class SystemApis { + private static volatile Method createContextAsUserMethod; + + // Not to be instantiated. + private SystemApis() {} + + /** + * Returns a new Context object whose methods act as if they were running in the given user. + * + * @throws ReflectiveOperationException if SDK_INT < R or host app lacks @SystemApi visibility + */ + public static Context createContextAsUser(Context context, UserHandle userHandle, int flags) + throws ReflectiveOperationException { + if (createContextAsUserMethod == null) { + synchronized (SystemApis.class) { + if (createContextAsUserMethod == null) { + createContextAsUserMethod = + Context.class.getMethod("createContextAsUser", UserHandle.class, int.class); + } + } + } + return (Context) createContextAsUserMethod.invoke(context, userHandle, flags); + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/TransactionUtils.java b/binder/src/main/java/io/grpc/binder/internal/TransactionUtils.java index c962554d125..2777a78d4ac 100644 --- a/binder/src/main/java/io/grpc/binder/internal/TransactionUtils.java +++ b/binder/src/main/java/io/grpc/binder/internal/TransactionUtils.java @@ -16,9 +16,13 @@ package io.grpc.binder.internal; +import android.os.Binder; import android.os.Parcel; import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; +import java.util.logging.Level; +import java.util.logging.Logger; +import io.grpc.binder.internal.LeakSafeOneWayBinder.TransactionHandler; import 
javax.annotation.Nullable; /** Constants and helpers for managing inbound / outbound transactions. */ @@ -99,4 +103,24 @@ static void fillInFlags(Parcel parcel, int flags) { parcel.writeInt(flags); parcel.setDataPosition(pos); } + + /** + * Decorates the given {@link TransactionHandler} with a wrapper that only forwards transactions + * from the given `allowedCallingUid`. + */ + static TransactionHandler newCallerFilteringHandler( + int allowedCallingUid, TransactionHandler wrapped) { + final Logger logger = Logger.getLogger(TransactionUtils.class.getName()); + return new TransactionHandler() { + @Override + public boolean handleTransaction(int code, Parcel data) { + int callingUid = Binder.getCallingUid(); + if (callingUid != allowedCallingUid) { + logger.log(Level.WARNING, "dropped txn from " + callingUid + " !=" + allowedCallingUid); + return false; + } + return wrapped.handleTransaction(code, data); + } + }; + } } diff --git a/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java b/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java index 6d7e53e5a19..d7d77d7feb1 100644 --- a/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java +++ b/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java @@ -18,11 +18,14 @@ import static android.content.Intent.URI_ANDROID_APP_SCHEME; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import android.content.ComponentName; import android.content.Context; import android.content.Intent; import android.net.Uri; +import android.os.Parcel; +import android.os.UserHandle; import androidx.test.core.app.ApplicationProvider; import com.google.common.testing.EqualsTester; import java.net.URISyntaxException; @@ -83,6 +86,32 @@ public void testAsBindIntent() { assertThat(addr.asBindIntent().filterEquals(bindIntent)).isTrue(); } + @Test + public void testPostCreateIntentMutation() { + Intent bindIntent = new 
Intent().setAction("foo-action").setComponent(hostComponent); + AndroidComponentAddress addr = AndroidComponentAddress.forBindIntent(bindIntent); + bindIntent.setAction("bar-action"); + assertThat(addr.asBindIntent().getAction()).isEqualTo("foo-action"); + } + + @Test + public void testPostBuildIntentMutation() { + Intent bindIntent = new Intent().setAction("foo-action").setComponent(hostComponent); + AndroidComponentAddress addr = + AndroidComponentAddress.newBuilder().setBindIntent(bindIntent).build(); + bindIntent.setAction("bar-action"); + assertThat(addr.asBindIntent().getAction()).isEqualTo("foo-action"); + } + + @Test + public void testBuilderMissingRequired() { + IllegalStateException ise = + assertThrows( + IllegalStateException.class, + () -> AndroidComponentAddress.newBuilder().setTargetUser(newUserHandle(123)).build()); + assertThat(ise.getMessage()).contains("bindIntent"); + } + @Test @Config(sdk = 30) public void testAsAndroidAppUriSdk30() throws URISyntaxException { @@ -117,13 +146,21 @@ public void testEquality() { AndroidComponentAddress.forContext(appContext), AndroidComponentAddress.forLocalComponent(appContext, appContext.getClass()), AndroidComponentAddress.forRemoteComponent( - appContext.getPackageName(), appContext.getClass().getName())) + appContext.getPackageName(), appContext.getClass().getName()), + AndroidComponentAddress.newBuilder() + .setBindIntentFromComponent(hostComponent) + .setTargetUser(null) + .build()) .addEqualityGroup( AndroidComponentAddress.forRemoteComponent("appy.mcappface", ".McActivity")) .addEqualityGroup(AndroidComponentAddress.forLocalComponent(appContext, getClass())) .addEqualityGroup( AndroidComponentAddress.forBindIntent( - new Intent().setAction("custom-action").setComponent(hostComponent))) + new Intent().setAction("custom-action").setComponent(hostComponent)), + AndroidComponentAddress.newBuilder() + .setBindIntent(new Intent().setAction("custom-action").setComponent(hostComponent)) + .setTargetUser(null) + 
.build()) .addEqualityGroup( AndroidComponentAddress.forBindIntent( new Intent() @@ -133,6 +170,31 @@ public void testEquality() { .testEquals(); } + @Test + public void testUnequalTargetUsers() { + new EqualsTester() + .addEqualityGroup( + AndroidComponentAddress.newBuilder() + .setBindIntentFromComponent(hostComponent) + .setTargetUser(newUserHandle(10)) + .build(), + AndroidComponentAddress.newBuilder() + .setBindIntentFromComponent(hostComponent) + .setTargetUser(newUserHandle(10)) + .build()) + .addEqualityGroup( + AndroidComponentAddress.newBuilder() + .setBindIntentFromComponent(hostComponent) + .setTargetUser(newUserHandle(11)) + .build()) + .addEqualityGroup( + AndroidComponentAddress.newBuilder() + .setBindIntentFromComponent(hostComponent) + .setTargetUser(null) + .build()) + .testEquals(); + } + @Test @Config(sdk = 30) public void testPackageFilterEquality30AndUp() { @@ -163,4 +225,15 @@ public void testPackageFilterEqualityPre30() { .setComponent(new ComponentName("pkg", "cls")))) .testEquals(); } + + private static UserHandle newUserHandle(int userId) { + Parcel parcel = Parcel.obtain(); + try { + parcel.writeInt(userId); + parcel.setDataPosition(0); + return new UserHandle(parcel); + } finally { + parcel.recycle(); + } + } } diff --git a/binder/src/test/java/io/grpc/binder/RobolectricBinderSecurityTest.java b/binder/src/test/java/io/grpc/binder/RobolectricBinderSecurityTest.java index ab81fc6b6d0..ffd1d89e69c 100644 --- a/binder/src/test/java/io/grpc/binder/RobolectricBinderSecurityTest.java +++ b/binder/src/test/java/io/grpc/binder/RobolectricBinderSecurityTest.java @@ -22,12 +22,13 @@ import static org.robolectric.Shadows.shadowOf; import android.app.Application; -import android.content.ComponentName; -import android.content.Intent; -import android.os.IBinder; -import android.os.Looper; -import androidx.lifecycle.LifecycleService; +import android.content.pm.ApplicationInfo; +import android.content.pm.PackageInfo; +import 
android.content.pm.ServiceInfo; import androidx.test.core.app.ApplicationProvider; +import androidx.test.core.content.pm.ApplicationInfoBuilder; +import androidx.test.core.content.pm.PackageInfoBuilder; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; @@ -42,90 +43,143 @@ import io.grpc.ServerServiceDefinition; import io.grpc.Status; import io.grpc.StatusRuntimeException; -import io.grpc.binder.internal.MainThreadScheduledExecutorService; import io.grpc.protobuf.lite.ProtoLiteUtils; import io.grpc.stub.ClientCalls; import io.grpc.stub.ServerCalls; import java.io.IOException; import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.robolectric.Robolectric; -import org.robolectric.RobolectricTestRunner; -import org.robolectric.android.controller.ServiceController; - -@RunWith(RobolectricTestRunner.class) +import org.robolectric.ParameterizedRobolectricTestRunner; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameter; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameters; +import org.robolectric.annotation.LooperMode; +import org.robolectric.annotation.LooperMode.Mode; + +@RunWith(ParameterizedRobolectricTestRunner.class) +@LooperMode(Mode.INSTRUMENTATION_TEST) public final class RobolectricBinderSecurityTest { private static final String SERVICE_NAME = "fake_service"; private static final String FULL_METHOD_NAME = "fake_service/fake_method"; private final Application context = ApplicationProvider.getApplicationContext(); - private ServiceController controller; - private SomeService service; + private final ArrayBlockingQueue> statusesToSet = + new ArrayBlockingQueue<>(128); 
private ManagedChannel channel; + private Server server; + + @Parameter public boolean preAuthServersParam; + + @Parameters(name = "preAuthServersParam={0}") + public static ImmutableList data() { + return ImmutableList.of(true, false); + } @Before public void setUp() { - controller = Robolectric.buildService(SomeService.class); - service = controller.create().get(); + ApplicationInfo serverAppInfo = + ApplicationInfoBuilder.newBuilder().setPackageName(context.getPackageName()).build(); + serverAppInfo.uid = android.os.Process.myUid(); + PackageInfo serverPkgInfo = + PackageInfoBuilder.newBuilder() + .setPackageName(serverAppInfo.packageName) + .setApplicationInfo(serverAppInfo) + .build(); + shadowOf(context.getPackageManager()).installPackage(serverPkgInfo); + + ServiceInfo serviceInfo = new ServiceInfo(); + serviceInfo.name = "SomeService"; + serviceInfo.packageName = serverAppInfo.packageName; + serviceInfo.applicationInfo = serverAppInfo; + shadowOf(context.getPackageManager()).addOrUpdateService(serviceInfo); + + AndroidComponentAddress listenAddress = + AndroidComponentAddress.forRemoteComponent(serviceInfo.packageName, serviceInfo.name); + + MethodDescriptor methodDesc = getMethodDescriptor(); + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + respObserver.onNext(req); + respObserver.onCompleted(); + }); + ServerMethodDefinition methodDef = + ServerMethodDefinition.create(methodDesc, callHandler); + ServerServiceDefinition def = + ServerServiceDefinition.builder(SERVICE_NAME).addMethod(methodDef).build(); + + IBinderReceiver binderReceiver = new IBinderReceiver(); + server = + BinderServerBuilder.forAddress(listenAddress, binderReceiver) + .addService(def) + .securityPolicy( + ServerSecurityPolicy.newBuilder() + .servicePolicy( + SERVICE_NAME, + new AsyncSecurityPolicy() { + @Override + public ListenableFuture checkAuthorizationAsync(int uid) { + SettableFuture status = SettableFuture.create(); + 
statusesToSet.add(status); + return status; + } + }) + .build()) + .build(); + try { + server.start(); + } catch (IOException e) { + throw new IllegalStateException(e); + } - AndroidComponentAddress listenAddress = AndroidComponentAddress.forContext(service); - ScheduledExecutorService executor = service.getExecutor(); + shadowOf(context) + .setComponentNameAndServiceForBindServiceForIntent( + listenAddress.asBindIntent(), + listenAddress.getComponent(), + checkNotNull(binderReceiver.get())); channel = BinderChannelBuilder.forAddress(listenAddress, context) - .executor(executor) - .scheduledExecutorService(executor) - .offloadExecutor(executor) + .preAuthorizeServers(preAuthServersParam) .build(); - idleLoopers(); } @After public void tearDown() { channel.shutdownNow(); - controller.destroy(); + server.shutdownNow(); } @Test public void testAsyncServerSecurityPolicy_failed_returnsFailureStatus() throws Exception { ListenableFuture status = makeCall(); - service.setSecurityPolicyStatusWhenReady(Status.ALREADY_EXISTS); - idleLoopers(); + statusesToSet.take().set(Status.ALREADY_EXISTS); - assertThat(Futures.getDone(status).getCode()).isEqualTo(Status.Code.ALREADY_EXISTS); + assertThat(status.get().getCode()).isEqualTo(Status.Code.ALREADY_EXISTS); } @Test public void testAsyncServerSecurityPolicy_failedFuture_failsWithCodeInternal() throws Exception { ListenableFuture status = makeCall(); - service.setSecurityPolicyFailed(new IllegalStateException("oops")); - idleLoopers(); + statusesToSet.take().setException(new IllegalStateException("oops")); - assertThat(Futures.getDone(status).getCode()).isEqualTo(Status.Code.INTERNAL); + assertThat(status.get().getCode()).isEqualTo(Status.Code.INTERNAL); } @Test public void testAsyncServerSecurityPolicy_allowed_returnsOkStatus() throws Exception { ListenableFuture status = makeCall(); - service.setSecurityPolicyStatusWhenReady(Status.OK); - idleLoopers(); + statusesToSet.take().set(Status.OK); - 
assertThat(Futures.getDone(status).getCode()).isEqualTo(Status.Code.OK); + assertThat(status.get().getCode()).isEqualTo(Status.Code.OK); } private ListenableFuture makeCall() { - ClientCall call = - channel.newCall( - getMethodDescriptor(), CallOptions.DEFAULT.withExecutor(service.getExecutor())); + ClientCall call = channel.newCall(getMethodDescriptor(), CallOptions.DEFAULT); ListenableFuture responseFuture = ClientCalls.futureUnaryCall(call, Empty.getDefaultInstance()); - idleLoopers(); - return Futures.catching( Futures.transform(responseFuture, unused -> Status.OK, directExecutor()), StatusRuntimeException.class, @@ -133,10 +187,6 @@ private ListenableFuture makeCall() { directExecutor()); } - private static void idleLoopers() { - shadowOf(Looper.getMainLooper()).idle(); - } - private static MethodDescriptor getMethodDescriptor() { MethodDescriptor.Marshaller marshaller = ProtoLiteUtils.marshaller(Empty.getDefaultInstance()); @@ -147,106 +197,4 @@ private static MethodDescriptor getMethodDescriptor() { .setSampledToLocalTracing(true) .build(); } - - private static class SomeService extends LifecycleService { - - private final IBinderReceiver binderReceiver = new IBinderReceiver(); - private final ArrayBlockingQueue> statusesToSet = - new ArrayBlockingQueue<>(128); - private Server server; - private final ScheduledExecutorService scheduledExecutorService = - new MainThreadScheduledExecutorService(); - - @Override - public void onCreate() { - super.onCreate(); - - MethodDescriptor methodDesc = getMethodDescriptor(); - ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall( - (req, respObserver) -> { - respObserver.onNext(req); - respObserver.onCompleted(); - }); - ServerMethodDefinition methodDef = - ServerMethodDefinition.create(methodDesc, callHandler); - ServerServiceDefinition def = - ServerServiceDefinition.builder(SERVICE_NAME).addMethod(methodDef).build(); - - server = - BinderServerBuilder.forAddress(AndroidComponentAddress.forContext(this), 
binderReceiver) - .addService(def) - .securityPolicy( - ServerSecurityPolicy.newBuilder() - .servicePolicy( - SERVICE_NAME, - new AsyncSecurityPolicy() { - @Override - public ListenableFuture checkAuthorizationAsync(int uid) { - return Futures.submitAsync( - () -> { - SettableFuture status = SettableFuture.create(); - statusesToSet.add(status); - return status; - }, - getExecutor()); - } - }) - .build()) - .executor(getExecutor()) - .scheduledExecutorService(getExecutor()) - .build(); - try { - server.start(); - } catch (IOException e) { - throw new IllegalStateException(e); - } - - Application context = ApplicationProvider.getApplicationContext(); - ComponentName componentName = new ComponentName(context, SomeService.class); - shadowOf(context) - .setComponentNameAndServiceForBindService( - componentName, checkNotNull(binderReceiver.get())); - } - - /** - * Returns an {@link ScheduledExecutorService} under which all of the gRPC computations run. The - * execution of any pending tasks on this executor can be triggered via {@link #idleLoopers()}. - */ - ScheduledExecutorService getExecutor() { - return scheduledExecutorService; - } - - void setSecurityPolicyStatusWhenReady(Status status) { - getNextEnqueuedStatus().set(status); - } - - void setSecurityPolicyFailed(Exception e) { - getNextEnqueuedStatus().setException(e); - } - - private SettableFuture getNextEnqueuedStatus() { - @Nullable SettableFuture future = statusesToSet.poll(); - while (future == null) { - // Keep idling until the future is available. - idleLoopers(); - future = statusesToSet.poll(); - } - return checkNotNull(future); - } - - @Override - public IBinder onBind(Intent intent) { - super.onBind(intent); - return checkNotNull(binderReceiver.get()); - } - - @Override - public void onDestroy() { - super.onDestroy(); - server.shutdownNow(); - } - - /** A future representing a task submitted to a {@link Handler}. 
*/ - } } diff --git a/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java b/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java index 84c76a84bf2..71180ed43c5 100644 --- a/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java +++ b/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java @@ -357,7 +357,7 @@ public void testIsDeviceOwner_failsWhenNoPackagesForUid() throws Exception { } @Test - @Config(sdk = 21) + @Config(sdk = Config.OLDEST_SDK) public void testIsProfileOwner_succeedsForProfileOwner() throws Exception { PackageInfo info = newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); @@ -371,7 +371,7 @@ public void testIsProfileOwner_succeedsForProfileOwner() throws Exception { } @Test - @Config(sdk = 21) + @Config(sdk = Config.OLDEST_SDK) public void testIsProfileOwner_failsForNotProfileOwner() throws Exception { PackageInfo info = newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); @@ -385,7 +385,7 @@ public void testIsProfileOwner_failsForNotProfileOwner() throws Exception { } @Test - @Config(sdk = 21) + @Config(sdk = Config.OLDEST_SDK) public void testIsProfileOwner_failsWhenNoPackagesForUid() throws Exception { policy = SecurityPolicies.isProfileOwner(appContext); @@ -425,7 +425,7 @@ public void testIsProfileOwnerOnOrgOwned_failsForProfileOwnerOnNonOrgOwned() thr } @Test - @Config(sdk = 21) + @Config(sdk = Config.OLDEST_SDK) public void testIsProfileOwnerOnOrgOwned_failsForNotProfileOwner() throws Exception { PackageInfo info = newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); @@ -439,7 +439,7 @@ public void testIsProfileOwnerOnOrgOwned_failsForNotProfileOwner() throws Except } @Test - @Config(sdk = 21) + @Config(sdk = Config.OLDEST_SDK) public void testIsProfileOwnerOnOrgOwned_failsWhenNoPackagesForUid() throws Exception { policy = SecurityPolicies.isProfileOwnerOnOrganizationOwnedDevice(appContext); diff --git 
a/binder/src/test/java/io/grpc/binder/internal/BinderServerTransportTest.java b/binder/src/test/java/io/grpc/binder/internal/BinderServerTransportTest.java index e56d860c091..d261ce43c8c 100644 --- a/binder/src/test/java/io/grpc/binder/internal/BinderServerTransportTest.java +++ b/binder/src/test/java/io/grpc/binder/internal/BinderServerTransportTest.java @@ -16,25 +16,25 @@ package io.grpc.binder.internal; -import static com.google.common.base.Preconditions.checkState; import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.isNull; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.doThrow; import static org.robolectric.Shadows.shadowOf; -import static org.robolectric.annotation.LooperMode.Mode.PAUSED; import android.os.IBinder; import android.os.Looper; import android.os.Parcel; +import android.os.RemoteException; import com.google.common.collect.ImmutableList; import io.grpc.Attributes; -import io.grpc.Metadata; +import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.internal.FixedObjectPool; -import io.grpc.internal.ServerStream; -import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.MockServerTransportListener; +import io.grpc.internal.ObjectPool; +import java.util.List; import java.util.concurrent.ScheduledExecutorService; import org.junit.Before; import org.junit.Rule; @@ -44,75 +44,126 @@ import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; import org.robolectric.RobolectricTestRunner; -import org.robolectric.annotation.LooperMode; /** * Low-level server-side transport tests for binder channel. Like BinderChannelSmokeTest, this * convers edge cases not exercised by AbstractTransportTest, but it deals with the * binderTransport.BinderServerTransport directly. 
*/ -@LooperMode(PAUSED) @RunWith(RobolectricTestRunner.class) public final class BinderServerTransportTest { @Rule public MockitoRule mocks = MockitoJUnit.rule(); private final ScheduledExecutorService executorService = new MainThreadScheduledExecutorService(); - private final TestTransportListener transportListener = new TestTransportListener(); + private MockServerTransportListener transportListener; @Mock IBinder mockBinder; - BinderTransport.BinderServerTransport transport; + BinderServerTransport transport; @Before public void setUp() throws Exception { - transport = - new BinderTransport.BinderServerTransport( - new FixedObjectPool<>(executorService), - Attributes.EMPTY, - ImmutableList.of(), - OneWayBinderProxy.IDENTITY_DECORATOR, - mockBinder); + transportListener = new MockServerTransportListener(transport); + } + + // Provide defaults so that we can "include only relevant details in tests." + BinderServerTransportBuilder newBinderServerTransportBuilder() { + return new BinderServerTransportBuilder() + .setExecutorServicePool(new FixedObjectPool<>(executorService)) + .setAttributes(Attributes.EMPTY) + .setStreamTracerFactories(ImmutableList.of()) + .setBinderDecorator(OneWayBinderProxy.IDENTITY_DECORATOR) + .setCallbackBinder(mockBinder); } @Test - public void testSetupTransactionFailureCausesMultipleShutdowns_b153460678() throws Exception { + public void testSetupTransactionFailureReportsMultipleTerminations_b153460678() throws Exception { // Make the binder fail the setup transaction. 
- when(mockBinder.transact(anyInt(), any(Parcel.class), isNull(), anyInt())).thenReturn(false); - transport.setServerTransportListener(transportListener); + doThrow(new RemoteException()) + .when(mockBinder) + .transact(anyInt(), any(Parcel.class), isNull(), anyInt()); + transport = newBinderServerTransportBuilder().setCallbackBinder(mockBinder).build(); + shadowOf(Looper.getMainLooper()).idle(); + transport.start(transportListener); + + // Now shut it down externally *before* executing Runnables scheduled on the executor. + transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); + shadowOf(Looper.getMainLooper()).idle(); + + assertThat(transportListener.isTerminated()).isTrue(); + } + + @Test + public void testClientBinderIsDeadOnArrival() throws Exception { + transport = newBinderServerTransportBuilder() + .setCallbackBinder(new FakeDeadBinder()) + .build(); + transport.start(transportListener); + shadowOf(Looper.getMainLooper()).idle(); + + assertThat(transportListener.isTerminated()).isTrue(); + } + + @Test + public void testStartAfterShutdownAndIdle() throws Exception { + transport = newBinderServerTransportBuilder().build(); + transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); + shadowOf(Looper.getMainLooper()).idle(); + transport.start(transportListener); + shadowOf(Looper.getMainLooper()).idle(); + + assertThat(transportListener.isTerminated()).isTrue(); + } - // Now shut it down. 
+ @Test + public void testStartAfterShutdownNoIdle() throws Exception { + transport = newBinderServerTransportBuilder().build(); transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); + transport.start(transportListener); shadowOf(Looper.getMainLooper()).idle(); - assertThat(transportListener.terminated).isTrue(); + assertThat(transportListener.isTerminated()).isTrue(); } - private static final class TestTransportListener implements ServerTransportListener { - - public boolean ready; - public boolean terminated; - - /** - * Called when a new stream was created by the remote client. - * - * @param stream the newly created stream. - * @param method the fully qualified method name being called on the server. - * @param headers containing metadata for the call. - */ - @Override - public void streamCreated(ServerStream stream, String method, Metadata headers) {} - - @Override - public Attributes transportReady(Attributes attributes) { - ready = true; - return attributes; + static class BinderServerTransportBuilder { + ObjectPool executorServicePool; + Attributes attributes; + List streamTracerFactories; + OneWayBinderProxy.Decorator binderDecorator; + IBinder callbackBinder; + + public BinderServerTransport build() { + return BinderServerTransport.create( + executorServicePool, attributes, streamTracerFactories, binderDecorator, callbackBinder); + } + + public BinderServerTransportBuilder setExecutorServicePool( + ObjectPool executorServicePool) { + this.executorServicePool = executorServicePool; + return this; + } + + public BinderServerTransportBuilder setAttributes(Attributes attributes) { + this.attributes = attributes; + return this; + } + + public BinderServerTransportBuilder setStreamTracerFactories( + List streamTracerFactories) { + this.streamTracerFactories = streamTracerFactories; + return this; + } + + public BinderServerTransportBuilder setBinderDecorator( + OneWayBinderProxy.Decorator binderDecorator) { + this.binderDecorator = binderDecorator; 
+ return this; } - @Override - public void transportTerminated() { - checkState(!terminated, "Terminated twice"); - terminated = true; + public BinderServerTransportBuilder setCallbackBinder(IBinder callbackBinder) { + this.callbackBinder = callbackBinder; + return this; } } } diff --git a/binder/src/test/java/io/grpc/binder/internal/IntentNameResolverProviderTest.java b/binder/src/test/java/io/grpc/binder/internal/IntentNameResolverProviderTest.java new file mode 100644 index 00000000000..2809a72fee1 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/IntentNameResolverProviderTest.java @@ -0,0 +1,130 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.grpc.binder.internal; + +import static android.os.Looper.getMainLooper; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.robolectric.Shadows.shadowOf; + +import android.app.Application; +import androidx.core.content.ContextCompat; +import androidx.test.core.app.ApplicationProvider; +import io.grpc.NameResolver; +import io.grpc.NameResolver.ResolutionResult; +import io.grpc.NameResolver.ServiceConfigParser; +import io.grpc.NameResolverProvider; +import io.grpc.SynchronizationContext; + +import io.grpc.binder.ApiConstants; +import java.net.URI; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoTestRule; +import org.robolectric.RobolectricTestRunner; + +/** A test for IntentNameResolverProvider.
 */ +@RunWith(RobolectricTestRunner.class) +public final class IntentNameResolverProviderTest { + + private final Application appContext = ApplicationProvider.getApplicationContext(); + private final SynchronizationContext syncContext = newSynchronizationContext(); + private final NameResolver.Args args = newNameResolverArgs(); + + private NameResolverProvider provider; + + @Rule public MockitoTestRule mockitoTestRule = MockitoJUnit.testRule(this); + @Mock public NameResolver.Listener2 mockListener; + @Captor public ArgumentCaptor resultCaptor; + + @Before + public void setUp() { + provider = new IntentNameResolverProvider(); + } + + @Test + public void testProviderScheme_returnsIntentScheme() throws Exception { + assertThat(provider.getDefaultScheme()) + .isEqualTo(IntentNameResolverProvider.ANDROID_INTENT_SCHEME); + } + + @Test + public void testNoResolverForUnknownScheme_returnsNull() throws Exception { + assertThat(provider.newNameResolver(URI.create("random://uri"), args)).isNull(); + } + + @Test + public void testResolutionWithBadUri_throwsIllegalArg() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> provider.newNameResolver(URI.create("intent:xxx#Intent;e.x=1;end;"), args)); + } + + @Test + public void testResolverForIntentScheme_returnsResolver() throws Exception { + URI uri = URI.create("intent:#Intent;action=action;end"); + NameResolver resolver = provider.newNameResolver(uri, args); + assertThat(resolver).isNotNull(); + assertThat(resolver.getServiceAuthority()).isEqualTo("localhost"); + syncContext.execute(() -> resolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(resultCaptor.getValue().getAddressesOrError()).isNotNull(); + syncContext.execute(resolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + @Test + public void testResolverForIntentScheme_returnsResolver_javaNetUri() throws Exception { + URI uri = new
URI("intent://authority/path#Intent;action=action;scheme=scheme;end"); + NameResolver resolver = provider.newNameResolver(uri, args); + assertThat(resolver).isNotNull(); + assertThat(resolver.getServiceAuthority()).isEqualTo("localhost"); + syncContext.execute(() -> resolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(resultCaptor.getValue().getAddressesOrError()).isNotNull(); + syncContext.execute(resolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + /** Returns a new test-specific {@link NameResolver.Args} instance. */ + private NameResolver.Args newNameResolverArgs() { + return NameResolver.Args.newBuilder() + .setDefaultPort(-1) + .setProxyDetector((target) -> null) // No proxies here. + .setSynchronizationContext(syncContext) + .setOffloadExecutor(ContextCompat.getMainExecutor(appContext)) + .setServiceConfigParser(mock(ServiceConfigParser.class)) + .setArg(ApiConstants.SOURCE_ANDROID_CONTEXT, appContext) + .build(); + } + + private static SynchronizationContext newSynchronizationContext() { + return new SynchronizationContext( + (thread, exception) -> { + throw new AssertionError(exception); + }); + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/IntentNameResolverTest.java b/binder/src/test/java/io/grpc/binder/internal/IntentNameResolverTest.java new file mode 100644 index 00000000000..b1bfcd4fd56 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/IntentNameResolverTest.java @@ -0,0 +1,531 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import static android.content.Intent.ACTION_PACKAGE_ADDED; +import static android.content.Intent.ACTION_PACKAGE_REPLACED; +import static android.os.Looper.getMainLooper; +import static android.os.Process.myUserHandle; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.robolectric.Shadows.shadowOf; + +import android.app.Application; +import android.content.ComponentName; +import android.content.Intent; +import android.content.IntentFilter; +import android.content.pm.ServiceInfo; +import android.net.Uri; +import android.os.UserHandle; +import android.os.UserManager; +import androidx.annotation.NonNull; +import androidx.core.content.ContextCompat; +import androidx.test.core.app.ApplicationProvider; +import com.google.common.collect.ImmutableList; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.NameResolver.ResolutionResult; +import io.grpc.NameResolver.ServiceConfigParser; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.SynchronizationContext; +import io.grpc.binder.AndroidComponentAddress; +import 
io.grpc.binder.ApiConstants; +import java.lang.Thread.UncaughtExceptionHandler; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoTestRule; +import org.robolectric.RobolectricTestRunner; +import org.robolectric.annotation.Config; +import org.robolectric.shadows.ShadowPackageManager; + +/** A test for IntentNameResolverProvider. */ +@RunWith(RobolectricTestRunner.class) +public final class IntentNameResolverTest { + + private static final ComponentName SOME_COMPONENT_NAME = + new ComponentName("com.foo.bar", "SomeComponent"); + private static final ComponentName ANOTHER_COMPONENT_NAME = + new ComponentName("org.blah", "AnotherComponent"); + private final Application appContext = ApplicationProvider.getApplicationContext(); + private final SynchronizationContext syncContext = newSynchronizationContext(); + private final NameResolver.Args args = newNameResolverArgs().build(); + + private final ShadowPackageManager shadowPackageManager = + shadowOf(appContext.getPackageManager()); + + @Rule public MockitoTestRule mockitoTestRule = MockitoJUnit.testRule(this); + @Mock public NameResolver.Listener2 mockListener; + @Captor public ArgumentCaptor resultCaptor; + + @Test + public void testResolverForIntentScheme_returnsResolverWithLocalHostAuthority() throws Exception { + NameResolver resolver = newNameResolver(newIntent()); + assertThat(resolver).isNotNull(); + assertThat(resolver.getServiceAuthority()).isEqualTo("localhost"); + } + + @Test + public void testResolutionWithoutServicesAvailable_returnsUnimplemented() throws Exception { + NameResolver nameResolver = newNameResolver(newIntent()); + syncContext.execute(() -> nameResolver.start(mockListener)); + 
shadowOf(getMainLooper()).idle(); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(resultCaptor.getValue().getAddressesOrError().getStatus().getCode()) + .isEqualTo(Status.UNIMPLEMENTED.getCode()); + } + + @Test + public void testResolutionWithMultipleServicesAvailable_returnsAndroidComponentAddresses() + throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + + // Adds another valid Service + shadowPackageManager.addServiceIfNotPresent(ANOTHER_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(ANOTHER_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = newNameResolver(intent); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly( + toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME)), + toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + @Test + public void testExplicitResolutionByComponent_returnsRestrictedResults() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + shadowPackageManager.addServiceIfNotPresent(ANOTHER_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(ANOTHER_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = + 
newNameResolver(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME)); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly(toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + @Test + public void testExplicitResolutionByPackage_returnsRestrictedResults() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + shadowPackageManager.addServiceIfNotPresent(ANOTHER_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(ANOTHER_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = + newNameResolver(intent.cloneFilter().setPackage(ANOTHER_COMPONENT_NAME.getPackageName())); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly(toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + @Test + public void testResolution_setsPreAuthEagAttribute() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = 
newNameResolver(intent); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly(toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME))); + assertThat( + getEagsOrThrow(resultCaptor.getValue()).stream() + .map(EquivalentAddressGroup::getAttributes) + .collect(toImmutableList()) + .get(0) + .get(ApiConstants.PRE_AUTH_SERVER_OVERRIDE)) + .isTrue(); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + @Test + public void testServiceRemoved_pushesUpdatedAndroidComponentAddresses() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + shadowPackageManager.addServiceIfNotPresent(ANOTHER_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(ANOTHER_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = newNameResolver(intent); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly( + toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME)), + toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + shadowPackageManager.removeService(ANOTHER_COMPONENT_NAME); + broadcastPackageChange(ACTION_PACKAGE_REPLACED, ANOTHER_COMPONENT_NAME.getPackageName()); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); + 
assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly(toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + + verifyNoMoreInteractions(mockListener); + assertThat(shadowOf(appContext).getRegisteredReceivers()).isEmpty(); + } + + @Test + @Config(sdk = 30) + public void testTargetAndroidUser_pushesUpdatedAddresses() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + NameResolver nameResolver = + newNameResolver( + intent, + newNameResolverArgs().setArg(ApiConstants.TARGET_ANDROID_USER, myUserHandle()).build()); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(resultCaptor.getValue().getAddressesOrError().getStatus().getCode()) + .isEqualTo(Status.UNIMPLEMENTED.getCode()); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + broadcastPackageChange(ACTION_PACKAGE_ADDED, SOME_COMPONENT_NAME.getPackageName()); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly( + ImmutableList.of( + AndroidComponentAddress.newBuilder() + .setTargetUser(myUserHandle()) + .setBindIntent(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME)) + .build())); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + + verifyNoMoreInteractions(mockListener); + assertThat(shadowOf(appContext).getRegisteredReceivers()).isEmpty(); + } + + @Test + @Config(sdk = 29) + public void testTargetAndroidUser_notSupported_throwsWithHelpfulMessage() throws Exception { + 
NameResolver.Args args = + newNameResolverArgs().setArg(ApiConstants.TARGET_ANDROID_USER, myUserHandle()).build(); + IllegalArgumentException iae = + assertThrows(IllegalArgumentException.class, () -> newNameResolver(newIntent(), args)); + assertThat(iae.getMessage()).contains("TARGET_ANDROID_USER"); + assertThat(iae.getMessage()).contains("SDK_INT >= R"); + } + + @Test + @Config(sdk = 29) + public void testServiceAppearsUponBootComplete_pushesUpdatedAndroidComponentAddresses() + throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + // Suppose this directBootAware=true Service appears in PackageManager before a user unlock. + shadowOf(appContext.getSystemService(UserManager.class)).setUserUnlocked(false); + ServiceInfo someServiceInfo = shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + someServiceInfo.directBootAware = true; + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = newNameResolver(intent); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly(toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME))); + + // TODO(b/331618070): Robolectric doesn't yet support ServiceInfo.directBootAware filtering. + // Simulate support by waiting for a user unlock to add this !directBootAware Service. 
+ ServiceInfo anotherServiceInfo = + shadowPackageManager.addServiceIfNotPresent(ANOTHER_COMPONENT_NAME); + anotherServiceInfo.directBootAware = false; + shadowPackageManager.addIntentFilterForService(ANOTHER_COMPONENT_NAME, serviceIntentFilter); + + shadowOf(appContext.getSystemService(UserManager.class)).setUserUnlocked(true); + broadcastUserUnlocked(myUserHandle()); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly( + toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME)), + toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + verifyNoMoreInteractions(mockListener); + } + + @Test + public void testRefresh_returnsSameAndroidComponentAddresses() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + shadowPackageManager.addServiceIfNotPresent(ANOTHER_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(ANOTHER_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = newNameResolver(intent); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly( + toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME)), + toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + syncContext.execute(nameResolver::refresh); + shadowOf(getMainLooper()).idle(); + 
verify(mockListener, never()).onError(any()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly( + toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME)), + toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + assertThat(shadowOf(appContext).getRegisteredReceivers()).isEmpty(); + } + + @Test + public void testRefresh_collapsesMultipleRequestsIntoOneLookup() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = newNameResolver(intent); + syncContext.execute(() -> nameResolver.start(mockListener)); // Should kick off the 1st lookup. + syncContext.execute(nameResolver::refresh); // Should queue a lookup to run when 1st finishes. + syncContext.execute(nameResolver::refresh); // Should be ignored since a lookup is already Q'd. + syncContext.execute(nameResolver::refresh); // Also ignored. 
+ shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly(toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + private void broadcastPackageChange(String action, String pkgName) { + Intent broadcast = new Intent(); + broadcast.setAction(action); + broadcast.setData(Uri.parse("package:" + pkgName)); + appContext.sendBroadcast(broadcast); + } + + private void broadcastUserUnlocked(UserHandle userHandle) { + Intent unlockedBroadcast = new Intent(Intent.ACTION_USER_UNLOCKED); + unlockedBroadcast.putExtra(Intent.EXTRA_USER, userHandle); + appContext.sendBroadcast(unlockedBroadcast); + } + + @Test + public void testResolutionOnResultThrows_onErrorNotCalled() throws Exception { + RetainingUncaughtExceptionHandler exceptionHandler = new RetainingUncaughtExceptionHandler(); + SynchronizationContext syncContext = new SynchronizationContext(exceptionHandler); + Intent intent = newIntent(); + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, newFilterMatching(intent)); + + @SuppressWarnings("serial") + class SomeRuntimeException extends RuntimeException {} + doThrow(SomeRuntimeException.class).when(mockListener).onResult2(any()); + + NameResolver nameResolver = + newNameResolver( + intent, newNameResolverArgs().setSynchronizationContext(syncContext).build()); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener).onResult2(any()); + verify(mockListener, never()).onError(any()); + assertThat(exceptionHandler.uncaught).hasSize(1); + assertThat(exceptionHandler.uncaught.get(0)).isInstanceOf(SomeRuntimeException.class); + } + + private static 
Intent newIntent() { + Intent intent = new Intent(); + intent.setAction("test.action"); + intent.setData(Uri.parse("grpc:ServiceName")); + return intent; + } + + private static IntentFilter newFilterMatching(Intent intent) { + IntentFilter filter = new IntentFilter(); + if (intent.getAction() != null) { + filter.addAction(intent.getAction()); + } + Uri data = intent.getData(); + if (data != null) { + if (data.getScheme() != null) { + filter.addDataScheme(data.getScheme()); + } + if (data.getSchemeSpecificPart() != null) { + filter.addDataSchemeSpecificPart(data.getSchemeSpecificPart(), 0); + } + } + Set categories = intent.getCategories(); + if (categories != null) { + for (String category : categories) { + filter.addCategory(category); + } + } + return filter; + } + + private static List getEagsOrThrow(ResolutionResult result) { + StatusOr> eags = result.getAddressesOrError(); + if (!eags.hasValue()) { + throw eags.getStatus().asRuntimeException(); + } + return eags.getValue(); + } + + // Extracts just the addresses from 'result's EquivalentAddressGroups. + private static ImmutableList> getAddressesOrThrow(ResolutionResult result) { + return getEagsOrThrow(result).stream() + .map(EquivalentAddressGroup::getAddresses) + .collect(toImmutableList()); + } + + // Converts given Intents to a list of ACAs, for convenient comparison with getAddressesOrThrow(). + private static ImmutableList toAddressList(Intent... bindIntents) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (Intent bindIntent : bindIntents) { + builder.add(AndroidComponentAddress.forBindIntent(bindIntent)); + } + return builder.build(); + } + + private NameResolver newNameResolver(Intent targetIntent) { + return newNameResolver(targetIntent, args); + } + + private NameResolver newNameResolver(Intent targetIntent, NameResolver.Args args) { + return new IntentNameResolver(targetIntent, args); + } + + /** Returns a new test-specific {@link NameResolver.Args} instance. 
*/ + private NameResolver.Args.Builder newNameResolverArgs() { + return NameResolver.Args.newBuilder() + .setDefaultPort(-1) + .setProxyDetector((target) -> null) // No proxies here. + .setSynchronizationContext(syncContext) + .setOffloadExecutor(ContextCompat.getMainExecutor(appContext)) + .setArg(ApiConstants.SOURCE_ANDROID_CONTEXT, appContext) + .setServiceConfigParser(mock(ServiceConfigParser.class)); + } + + /** + * Returns a test {@link SynchronizationContext}. + * + *

Exceptions will cause the test to fail with {@link AssertionError}. + */ + private static SynchronizationContext newSynchronizationContext() { + return new SynchronizationContext( + (thread, exception) -> { + throw new AssertionError(exception); + }); + } + + static final class RetainingUncaughtExceptionHandler implements UncaughtExceptionHandler { + final ArrayList uncaught = new ArrayList<>(); + + @Override + public void uncaughtException(@NonNull Thread t, @NonNull Throwable e) { + uncaught.add(e); + } + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java b/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java index 60e7c163105..c662cafe5fa 100644 --- a/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java +++ b/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java @@ -96,7 +96,7 @@ private static final class TestCallback implements ClientTransport.PingCallback private int numCallbacks; private boolean success; private boolean failure; - private Throwable failureException; + private Status failureStatus; private long roundtripTimeNanos; @Override @@ -107,10 +107,10 @@ public synchronized void onSuccess(long roundtripTimeNanos) { } @Override - public synchronized void onFailure(Throwable failureException) { + public synchronized void onFailure(Status failureStatus) { numCallbacks += 1; failure = true; - this.failureException = failureException; + this.failureStatus = failureStatus; } public void assertNotCalled() { @@ -130,13 +130,13 @@ public void assertSuccess(long expectRoundTripTimeNanos) { public void assertFailure(Status status) { assertThat(numCallbacks).isEqualTo(1); assertThat(failure).isTrue(); - assertThat(((StatusException) failureException).getStatus()).isSameInstanceAs(status); + assertThat(failureStatus).isSameInstanceAs(status); } public void assertFailure(Status.Code statusCode) { assertThat(numCallbacks).isEqualTo(1); assertThat(failure).isTrue(); - assertThat(((StatusException) 
failureException).getStatus().getCode()).isEqualTo(statusCode); + assertThat(failureStatus.getCode()).isEqualTo(statusCode); } } } diff --git a/binder/src/test/java/io/grpc/binder/internal/RobolectricBinderTransportTest.java b/binder/src/test/java/io/grpc/binder/internal/RobolectricBinderTransportTest.java new file mode 100644 index 00000000000..8282f5e1025 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/RobolectricBinderTransportTest.java @@ -0,0 +1,436 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.binder.internal; + +import static android.os.IBinder.FLAG_ONEWAY; +import static android.os.Process.myUid; +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.grpc.binder.internal.BinderTransport.REMOTE_UID; +import static io.grpc.binder.internal.BinderTransport.SETUP_TRANSPORT; +import static io.grpc.binder.internal.BinderTransport.SHUTDOWN_TRANSPORT; +import static io.grpc.binder.internal.BinderTransport.WIRE_FORMAT_VERSION; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.Assume.assumeTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; +import static org.robolectric.Shadows.shadowOf; + +import android.app.Application; +import android.content.Intent; +import android.content.pm.ApplicationInfo; +import android.content.pm.PackageInfo; +import android.content.pm.ServiceInfo; +import android.os.Binder; +import android.os.Parcel; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.core.content.pm.ApplicationInfoBuilder; +import androidx.test.core.content.pm.PackageInfoBuilder; +import com.google.common.collect.ImmutableList; +import com.google.common.truth.TruthJUnit; +import io.grpc.Attributes; +import io.grpc.InternalChannelz.SocketStats; +import io.grpc.ServerStreamTracer; +import io.grpc.Status; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.ApiConstants; +import io.grpc.binder.AsyncSecurityPolicy; +import io.grpc.binder.SecurityPolicies; +import io.grpc.binder.internal.SettableAsyncSecurityPolicy.AuthRequest; +import io.grpc.internal.AbstractTransportTest; +import io.grpc.internal.ClientTransport; +import 
io.grpc.internal.ClientTransportFactory.ClientTransportOptions; +import io.grpc.internal.ConnectionClientTransport; +import io.grpc.internal.DisconnectError; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.InternalServer; +import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.MockServerTransportListener; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.robolectric.ParameterizedRobolectricTestRunner; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameter; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameters; +import org.robolectric.annotation.LooperMode; +import org.robolectric.annotation.LooperMode.Mode; +import org.robolectric.shadows.ShadowBinder; + +/** + * All of the AbstractTransportTest cases applied to {@link BinderTransport} running in a + * Robolectric environment. + * + *

Runs much faster than BinderTransportTest and doesn't require an Android device/emulator. + * Somewhat less realistic but allows simulating behavior that would be difficult or impossible with + * real Android. + * + *

NB: Unlike most robolectric tests, we run in {@link LooperMode.Mode#INSTRUMENTATION_TEST}, + * meaning test cases don't run on the main thread. This supports the AbstractTransportTest approach + * where the test thread frequently blocks waiting for transport state changes to take effect. + */ +@RunWith(ParameterizedRobolectricTestRunner.class) +@LooperMode(Mode.INSTRUMENTATION_TEST) +public final class RobolectricBinderTransportTest extends AbstractTransportTest { + + static final int SERVER_APP_UID = 11111; + static final int EPHEMERAL_SERVER_UID = 22222; // UID of isolated server process. + + private final Application application = ApplicationProvider.getApplicationContext(); + private final ObjectPool executorServicePool = + SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); + private final ObjectPool offloadExecutorPool = + SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR); + private final ObjectPool serverExecutorPool = + SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR); + + @Rule public MockitoRule mocks = MockitoJUnit.rule(); + + @Mock AsyncSecurityPolicy mockClientSecurityPolicy; + + @Captor ArgumentCaptor statusCaptor; + + ApplicationInfo serverAppInfo; + PackageInfo serverPkgInfo; + ServiceInfo serviceInfo; + + private int nextServerAddress; + + @Parameter(value = 0) + public boolean preAuthServersParam; + + @Parameter(value = 1) + public boolean useLegacyAuthStrategy; + + @Parameters(name = "preAuthServersParam={0};useLegacyAuthStrategy={1}") + public static ImmutableList data() { + return ImmutableList.of( + new Object[] {false, false}, + new Object[] {false, true}, + new Object[] {true, false}, + new Object[] {true, true}); + } + + @Override + public void setUp() { + serverAppInfo = + ApplicationInfoBuilder.newBuilder().setPackageName("the.server.package").build(); + serverAppInfo.uid = myUid(); + serverPkgInfo = + PackageInfoBuilder.newBuilder() + .setPackageName(serverAppInfo.packageName) + 
.setApplicationInfo(serverAppInfo) + .build(); + shadowOf(application.getPackageManager()).installPackage(serverPkgInfo); + + serviceInfo = new ServiceInfo(); + serviceInfo.name = "SomeService"; + serviceInfo.packageName = serverAppInfo.packageName; + serviceInfo.applicationInfo = serverAppInfo; + shadowOf(application.getPackageManager()).addOrUpdateService(serviceInfo); + + super.setUp(); + } + + @Before + public void requestRealisticBindServiceBehavior() { + shadowOf(application).setBindServiceCallsOnServiceConnectedDirectly(false); + shadowOf(application).setUnbindServiceCallsOnServiceDisconnected(false); + } + + @Override + protected InternalServer newServer(List streamTracerFactories) { + AndroidComponentAddress listenAddr = + AndroidComponentAddress.forBindIntent( + new Intent() + .setClassName(serviceInfo.packageName, serviceInfo.name) + .setAction("io.grpc.action.BIND." + nextServerAddress++)); + + BinderServer binderServer = + new BinderServer.Builder() + .setListenAddress(listenAddr) + .setExecutorPool(serverExecutorPool) + .setExecutorServicePool(executorServicePool) + .setStreamTracerFactories(streamTracerFactories) + .build(); + + shadowOf(application.getPackageManager()).addServiceIfNotPresent(listenAddr.getComponent()); + shadowOf(application) + .setComponentNameAndServiceForBindServiceForIntent( + listenAddr.asBindIntent(), listenAddr.getComponent(), binderServer.getHostBinder()); + return binderServer; + } + + @Override + protected InternalServer newServer( + int port, List streamTracerFactories) { + if (port > 0) { + // TODO: TCP ports have no place in an *abstract* transport test. Replace with SocketAddress. 
+ throw new UnsupportedOperationException(); + } + return newServer(streamTracerFactories); + } + + BinderClientTransportFactory.Builder newClientTransportFactoryBuilder() { + return new BinderClientTransportFactory.Builder() + .setPreAuthorizeServers(preAuthServersParam) + .setUseLegacyAuthStrategy(useLegacyAuthStrategy) + .setSourceContext(application) + .setScheduledExecutorPool(executorServicePool) + .setOffloadExecutorPool(offloadExecutorPool); + } + + BinderClientTransportBuilder newClientTransportBuilder() { + return new BinderClientTransportBuilder() + .setFactory(newClientTransportFactoryBuilder().buildClientTransportFactory()) + .setServerAddress(server.getListenSocketAddress()); + } + + @Override + protected ManagedClientTransport newClientTransport(InternalServer server) { + ClientTransportOptions options = new ClientTransportOptions(); + options.setEagAttributes(eagAttrs()); + options.setChannelLogger(transportLogger()); + + return newClientTransportBuilder() + .setServerAddress(server.getListenSocketAddress()) + .setOptions(options) + .build(); + } + + @Override + protected String testAuthority(InternalServer server) { + return ((AndroidComponentAddress) server.getListenSocketAddress()).getAuthority(); + } + + @Test + public void clientAuthorizesServerUidsInOrder() throws Exception { + // TODO(jdcormie): In real Android, Binder#getCallingUid is thread-local but Robolectric only + // lets us fake value this *globally*. So the ShadowBinder#setCallingUid() here unrealistically + // affects the server's view of the client's uid too. For now this doesn't matter because this + // test never exercises server SecurityPolicy. 
+ ShadowBinder.setCallingUid(EPHEMERAL_SERVER_UID); + + serverPkgInfo.applicationInfo.uid = SERVER_APP_UID; + shadowOf(application.getPackageManager()).installPackage(serverPkgInfo); + shadowOf(application.getPackageManager()).addOrUpdateService(serviceInfo); + server = newServer(ImmutableList.of()); + server.start(serverListener); + + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + client = + newClientTransportBuilder() + .setFactory( + newClientTransportFactoryBuilder() + .setSecurityPolicy(securityPolicy) + .buildClientTransportFactory()) + .build(); + runIfNotNull(client.start(mockClientTransportListener)); + + if (preAuthServersParam) { + AuthRequest preAuthRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_MS, MILLISECONDS); + assertThat(preAuthRequest.uid).isEqualTo(SERVER_APP_UID); + verify(mockClientTransportListener, never()).transportReady(); + preAuthRequest.setResult(Status.OK); + } + + AuthRequest authRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_MS, MILLISECONDS); + if (useLegacyAuthStrategy) { + assertThat(authRequest.uid).isEqualTo(EPHEMERAL_SERVER_UID); + } else { + assertThat(authRequest.uid).isEqualTo(SERVER_APP_UID); + } + verify(mockClientTransportListener, never()).transportReady(); + authRequest.setResult(Status.OK); + + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + } + + @Test + public void eagAttributeCanOverrideChannelPreAuthServerSetting() throws Exception { + server.start(serverListener); + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + ClientTransportOptions options = new ClientTransportOptions(); + options.setEagAttributes( + Attributes.newBuilder().set(ApiConstants.PRE_AUTH_SERVER_OVERRIDE, true).build()); + client = + newClientTransportBuilder() + .setOptions(options) + .setFactory( + newClientTransportFactoryBuilder() + .setPreAuthorizeServers(preAuthServersParam) // To be overridden. 
+ .setSecurityPolicy(securityPolicy) + .buildClientTransportFactory()) + .build(); + runIfNotNull(client.start(mockClientTransportListener)); + + AuthRequest preAuthRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_MS, MILLISECONDS); + verify(mockClientTransportListener, never()).transportReady(); + preAuthRequest.setResult(Status.OK); + + AuthRequest authRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_MS, MILLISECONDS); + verify(mockClientTransportListener, never()).transportReady(); + authRequest.setResult(Status.OK); + + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + } + + @Test + public void clientIgnoresDuplicateSetupTransaction() throws Exception { + server.start(serverListener); + client = + newClientTransportBuilder() + .setFactory( + newClientTransportFactoryBuilder() + .setSecurityPolicy(SecurityPolicies.internalOnly()) + .buildClientTransportFactory()) + .build(); + runIfNotNull(client.start(mockClientTransportListener)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + + assertThat(((ConnectionClientTransport) client).getAttributes().get(REMOTE_UID)) + .isEqualTo(myUid()); + + Parcel setupParcel = Parcel.obtain(); + try { + setupParcel.writeInt(WIRE_FORMAT_VERSION); + setupParcel.writeStrongBinder(new Binder()); + setupParcel.setDataPosition(0); + ShadowBinder.setCallingUid(1 + myUid()); + ((BinderClientTransport) client).handleTransaction(SETUP_TRANSPORT, setupParcel); + } finally { + ShadowBinder.setCallingUid(myUid()); + setupParcel.recycle(); + } + + assertThat(((ConnectionClientTransport) client).getAttributes().get(REMOTE_UID)) + .isEqualTo(myUid()); + } + + @Test + public void clientIgnoresTransactionFromNonServerUids() throws Exception { + server.start(serverListener); + + // This test is not applicable to the new auth strategy which keeps the client Binder a secret. 
+ assumeTrue(useLegacyAuthStrategy); + + client = newClientTransport(server); + startTransport(client, mockClientTransportListener); + + int serverUid = ((ConnectionClientTransport) client).getAttributes().get(REMOTE_UID); + int someOtherUid = 1 + serverUid; + sendShutdownTransportTransactionAsUid(client, someOtherUid); + + // Demonstrate that the transport is still working and that shutdown transaction was ignored. + ClientTransport.PingCallback mockPingCallback = mock(ClientTransport.PingCallback.class); + client.ping(mockPingCallback, directExecutor()); + verify(mockPingCallback, timeout(TIMEOUT_MS)).onSuccess(anyLong()); + + // Try again as the expected uid to demonstrate that this wasn't ignored for some other reason. + sendShutdownTransportTransactionAsUid(client, serverUid); + + verify(mockClientTransportListener, timeout(TIMEOUT_MS)) + .transportShutdown(statusCaptor.capture(), any(DisconnectError.class)); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(statusCaptor.getValue().getDescription()).contains("shutdown"); + } + + static void sendShutdownTransportTransactionAsUid(ClientTransport client, int sendingUid) { + int originalUid = Binder.getCallingUid(); + try { + ShadowBinder.setCallingUid(sendingUid); + ((BinderClientTransport) client) + .getIncomingBinderForTesting() + .onTransact(SHUTDOWN_TRANSPORT, null, null, FLAG_ONEWAY); + } finally { + ShadowBinder.setCallingUid(originalUid); + } + } + + @Test + public void clientReportsAuthzErrorToServer() throws Exception { + server.start(serverListener); + client = + newClientTransportBuilder() + .setFactory( + newClientTransportFactoryBuilder() + .setSecurityPolicy(SecurityPolicies.permissionDenied("test")) + .buildClientTransportFactory()) + .build(); + runIfNotNull(client.start(mockClientTransportListener)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)) + .transportShutdown(statusCaptor.capture(), any(DisconnectError.class)); + 
assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.PERMISSION_DENIED); + + // Client doesn't tell the server in this case by design -- we don't even want to start it! + TruthJUnit.assume().that(preAuthServersParam).isFalse(); + // Similar story here. The client won't send a setup transaction to an unauthorized server. + TruthJUnit.assume().that(useLegacyAuthStrategy).isTrue(); + + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, MILLISECONDS); + serverTransportListener.waitForTermination(TIMEOUT_MS, MILLISECONDS); + assertThat(serverTransportListener.isTerminated()).isTrue(); + } + + @Test + @Override + // We don't quite pass the official/abstract version of this test yet because + // today's binder client and server transports have different ideas of each others' address. + // TODO(#12347): Remove this @Override once this difference is resolved. + public void socketStats() throws Exception { + server.start(serverListener); + ManagedClientTransport client = newClientTransport(server); + startTransport(client, mockClientTransportListener); + + SocketStats clientSocketStats = client.getStats().get(); + assertThat(clientSocketStats.local).isInstanceOf(AndroidComponentAddress.class); + assertThat(((AndroidComponentAddress) clientSocketStats.remote).getPackage()) + .isEqualTo(((AndroidComponentAddress) server.getListenSocketAddress()).getPackage()); + + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, MILLISECONDS); + SocketStats serverSocketStats = serverTransportListener.transport.getStats().get(); + assertThat(serverSocketStats.local).isEqualTo(server.getListenSocketAddress()); + assertThat(serverSocketStats.remote).isEqualTo(new BoundClientAddress(myUid())); + } + + @Test + @Ignore("See BinderTransportTest#flowControlPushBack") + @Override + public void flowControlPushBack() {} + + @Test + @Ignore("See 
BinderTransportTest#serverAlreadyListening") + @Override + public void serverAlreadyListening() {} +} diff --git a/binder/src/test/java/io/grpc/binder/internal/ServiceBindingTest.java b/binder/src/test/java/io/grpc/binder/internal/ServiceBindingTest.java index b44692f560d..0f57b6f8a30 100644 --- a/binder/src/test/java/io/grpc/binder/internal/ServiceBindingTest.java +++ b/binder/src/test/java/io/grpc/binder/internal/ServiceBindingTest.java @@ -19,15 +19,17 @@ import static android.content.Context.BIND_AUTO_CREATE; import static android.os.Looper.getMainLooper; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import static org.robolectric.Shadows.shadowOf; -import static org.robolectric.annotation.LooperMode.Mode.PAUSED; import android.app.Application; import android.app.admin.DevicePolicyManager; import android.content.ComponentName; import android.content.Context; import android.content.Intent; +import android.content.pm.ServiceInfo; +import android.os.Build; import android.os.IBinder; import android.os.Parcel; import android.os.UserHandle; @@ -35,6 +37,7 @@ import androidx.test.core.app.ApplicationProvider; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.StatusException; import io.grpc.binder.BinderChannelCredentials; import io.grpc.binder.internal.Bindable.Observer; import java.util.Arrays; @@ -48,11 +51,8 @@ import org.mockito.junit.MockitoRule; import org.robolectric.RobolectricTestRunner; import org.robolectric.annotation.Config; -import org.robolectric.annotation.LooperMode; import org.robolectric.shadows.ShadowApplication; -import org.robolectric.shadows.ShadowDevicePolicyManager; -@LooperMode(PAUSED) @RunWith(RobolectricTestRunner.class) public final class ServiceBindingTest { @@ -62,6 +62,7 @@ public final class ServiceBindingTest { private Application appContext; private ComponentName serviceComponent; + private ServiceInfo serviceInfo = new 
ServiceInfo(); private ShadowApplication shadowApplication; private TestObserver observer; private ServiceBinding binding; @@ -70,13 +71,17 @@ public final class ServiceBindingTest { public void setUp() { appContext = ApplicationProvider.getApplicationContext(); serviceComponent = new ComponentName("DUMMY", "SERVICE"); + serviceInfo.packageName = serviceComponent.getPackageName(); + serviceInfo.name = serviceComponent.getClassName(); observer = new TestObserver(); shadowApplication = shadowOf(appContext); shadowApplication.setComponentNameAndServiceForBindService(serviceComponent, mockBinder); + shadowOf(appContext.getPackageManager()).addOrUpdateService(serviceInfo); // Don't call onServiceDisconnected() upon unbindService(), just like the real Android doesn't. shadowApplication.setUnbindServiceCallsOnServiceDisconnected(false); + shadowApplication.setBindServiceCallsOnServiceConnectedDirectly(false); binding = newBuilder().build(); shadowOf(getMainLooper()).idle(); @@ -110,6 +115,32 @@ public void testBind() throws Exception { assertThat(binding.isSourceContextCleared()).isFalse(); } + @Test + public void testGetConnectedServiceInfo() throws Exception { + binding = newBuilder().setTargetComponent(serviceComponent).build(); + binding.bind(); + shadowOf(getMainLooper()).idle(); + + assertThat(observer.gotBoundEvent).isTrue(); + + ServiceInfo serviceInfo = binding.getConnectedServiceInfo(); + assertThat(serviceInfo.name).isEqualTo(serviceComponent.getClassName()); + assertThat(serviceInfo.packageName).isEqualTo(serviceComponent.getPackageName()); + } + + @Test + public void testGetConnectedServiceInfoThrows() throws Exception { + binding = newBuilder().setTargetComponent(serviceComponent).build(); + binding.bind(); + shadowOf(getMainLooper()).idle(); + + assertThat(observer.gotBoundEvent).isTrue(); + shadowOf(appContext.getPackageManager()).removeService(serviceComponent); + + StatusException se = assertThrows(StatusException.class, 
binding::getConnectedServiceInfo); + assertThat(se.getStatus().getCode()).isEqualTo(Code.UNIMPLEMENTED); + } + @Test public void testBindingIntent() throws Exception { shadowApplication.setComponentNameAndServiceForBindService(null, null); @@ -279,16 +310,112 @@ public void testBindWithTargetUserHandle() throws Exception { assertThat(binding.isSourceContextCleared()).isFalse(); } + @Test + public void testResolve() throws Exception { + serviceInfo.processName = "x"; // ServiceInfo has no equals() so look for one distinctive field. + shadowOf(appContext.getPackageManager()).addOrUpdateService(serviceInfo); + ServiceInfo resolvedServiceInfo = binding.resolve(); + assertThat(resolvedServiceInfo.processName).isEqualTo(serviceInfo.processName); + } + + @Test + @Config(sdk = 33) + public void testResolveWithTargetUserHandle() throws Exception { + serviceInfo.processName = "x"; // ServiceInfo has no equals() so look for one distinctive field. + // Robolectric just ignores the user arg to resolveServiceAsUser() so this is all we can do. 
+ shadowOf(appContext.getPackageManager()).addOrUpdateService(serviceInfo); + binding = newBuilder().setTargetUserHandle(generateUserHandle(/* userId= */ 0)).build(); + ServiceInfo resolvedServiceInfo = binding.resolve(); + assertThat(resolvedServiceInfo.processName).isEqualTo(serviceInfo.processName); + } + + @Test + @Config(sdk = 29) + public void testResolveWithUnsupportedTargetUserHandle() throws Exception { + binding = newBuilder().setTargetUserHandle(generateUserHandle(/* userId= */ 0)).build(); + StatusException statusException = assertThrows(StatusException.class, binding::resolve); + assertThat(statusException.getStatus().getCode()).isEqualTo(Code.INTERNAL); + assertThat(statusException.getStatus().getDescription()).contains("SDK_INT >= R"); + } + + @Test + public void testResolveNonExistentServiceThrows() throws Exception { + ComponentName doesNotExistService = new ComponentName("does.not.exist", "NoService"); + binding = newBuilder().setTargetComponent(doesNotExistService).build(); + StatusException statusException = assertThrows(StatusException.class, binding::resolve); + assertThat(statusException.getStatus().getCode()).isEqualTo(Code.UNIMPLEMENTED); + assertThat(statusException.getStatus().getDescription()).contains("does.not.exist"); + } + + @Test + @Config(sdk = 33) + public void testResolveNonExistentServiceWithTargetUserThrows() throws Exception { + ComponentName doesNotExistService = new ComponentName("does.not.exist", "NoService"); + binding = + newBuilder() + .setTargetUserHandle(generateUserHandle(/* userId= */ 12345)) + .setTargetComponent(doesNotExistService) + .build(); + StatusException statusException = assertThrows(StatusException.class, binding::resolve); + assertThat(statusException.getStatus().getCode()).isEqualTo(Code.UNIMPLEMENTED); + assertThat(statusException.getStatus().getDescription()).contains("does.not.exist"); + assertThat(statusException.getStatus().getDescription()).contains("12345"); + } + + @Test + @Config(sdk = 30) + 
public void testBindService_doesNotThrowInternalErrorWhenSdkAtLeastR() { + UserHandle userHandle = generateUserHandle(/* userId= */ 12345); + binding = newBuilder().setTargetUserHandle(userHandle).build(); + binding.bind(); + shadowOf(getMainLooper()).idle(); + + assertThat(Build.VERSION.SDK_INT).isEqualTo(Build.VERSION_CODES.R); + assertThat(observer.unboundReason).isNull(); + } + + @Test + @Config(sdk = 28) + public void testBindServiceAsUser_returnsErrorWhenSdkBelowR() { + UserHandle userHandle = generateUserHandle(/* userId= */ 12345); + binding = newBuilder().setTargetUserHandle(userHandle).build(); + binding.bind(); + shadowOf(getMainLooper()).idle(); + + assertThat(observer.unboundReason.getCode()).isEqualTo(Code.INTERNAL); + assertThat(observer.unboundReason.getDescription()) + .isEqualTo("Cross user Channel requires Android R+"); + } + + @Test + @Config(sdk = 28) + public void testDevicePolicyBlind_returnsErrorWhenSdkBelowR() { + ComponentName adminComponent = new ComponentName(appContext, "DevicePolicyAdmin"); + UserHandle user10 = generateUserHandle(/* userId= */ 10); + allowBindDeviceAdminForUser(appContext, adminComponent, user10); + binding = + newBuilder() + .setTargetUserHandle(user10) + .setChannelCredentials(BinderChannelCredentials.forDevicePolicyAdmin(adminComponent)) + .build(); + binding.bind(); + shadowOf(getMainLooper()).idle(); + + assertThat(observer.unboundReason.getCode()).isEqualTo(Code.INTERNAL); + assertThat(observer.unboundReason.getDescription()) + .isEqualTo("Device policy admin binding requires Android R+"); + } + @Test @Config(sdk = 30) public void testBindWithDeviceAdmin() throws Exception { - String deviceAdminClassName = "DevicePolicyAdmin"; - ComponentName adminComponent = new ComponentName(appContext, deviceAdminClassName); - allowBindDeviceAdminForUser(appContext, adminComponent, /* userId= */ 0); + ComponentName adminComponent = new ComponentName(appContext, "DevicePolicyAdmin"); + UserHandle user0 = generateUserHandle(/* 
userId= */ 0); + allowBindDeviceAdminForUser(appContext, adminComponent, user0); binding = newBuilder() - .setTargetUserHandle(UserHandle.getUserHandleForUid(/* userId= */ 0)) - .setTargetUserHandle(generateUserHandle(/* userId= */ 0)) + .setTargetUserHandle(user0) + .setTargetComponent(serviceComponent) .setChannelCredentials(BinderChannelCredentials.forDevicePolicyAdmin(adminComponent)) .build(); shadowOf(getMainLooper()).idle(); @@ -301,6 +428,10 @@ public void testBindWithDeviceAdmin() throws Exception { assertThat(observer.binder).isSameInstanceAs(mockBinder); assertThat(observer.gotUnboundEvent).isFalse(); assertThat(binding.isSourceContextCleared()).isFalse(); + + ServiceInfo serviceInfo = binding.getConnectedServiceInfo(); + assertThat(serviceInfo.name).isEqualTo(serviceComponent.getClassName()); + assertThat(serviceInfo.packageName).isEqualTo(serviceComponent.getPackageName()); } private void assertNoLockHeld() { @@ -316,15 +447,10 @@ private void assertNoLockHeld() { } private static void allowBindDeviceAdminForUser( - Context context, ComponentName admin, int userId) { - ShadowDevicePolicyManager devicePolicyManager = - shadowOf(context.getSystemService(DevicePolicyManager.class)); - devicePolicyManager.setDeviceOwner(admin); - devicePolicyManager.setBindDeviceAdminTargetUsers( - Arrays.asList(UserHandle.getUserHandleForUid(userId))); - shadowOf((DevicePolicyManager) context.getSystemService(Context.DEVICE_POLICY_SERVICE)); - devicePolicyManager.setDeviceOwner(admin); - devicePolicyManager.setBindDeviceAdminTargetUsers(Arrays.asList(generateUserHandle(userId))); + Context context, ComponentName admin, UserHandle user) { + DevicePolicyManager devicePolicyManager = context.getSystemService(DevicePolicyManager.class); + shadowOf(devicePolicyManager).setBindDeviceAdminTargetUsers(Arrays.asList(user)); + shadowOf(devicePolicyManager).setDeviceOwner(admin); } /** Generate UserHandles the hard way. 
*/ diff --git a/binder/src/test/java/io/grpc/binder/internal/SimplePromiseTest.java b/binder/src/test/java/io/grpc/binder/internal/SimplePromiseTest.java new file mode 100644 index 00000000000..6486ff5e8a1 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/SimplePromiseTest.java @@ -0,0 +1,143 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import io.grpc.binder.internal.SimplePromise.Listener; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public final class SimplePromiseTest { + + private static final String FULFILLED_VALUE = "a fulfilled value"; + + @Mock private Listener mockListener1; + @Mock private Listener mockListener2; + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + private SimplePromise promise = new SimplePromise<>(); + + @Before + public void setUp() { + } + + @Test + public void 
get_beforeFulfilled_throws() { + IllegalStateException e = assertThrows(IllegalStateException.class, () -> promise.get()); + assertThat(e).hasMessageThat().isEqualTo("Not yet set!"); + } + + @Test + public void get_afterFulfilled_returnsValue() { + promise.set(FULFILLED_VALUE); + assertThat(promise.get()).isEqualTo(FULFILLED_VALUE); + } + + @Test + public void set_withNull_throws() { + assertThrows(NullPointerException.class, () -> promise.set(null)); + } + + @Test + public void set_calledTwice_throws() { + promise.set(FULFILLED_VALUE); + IllegalStateException e = + assertThrows(IllegalStateException.class, () -> promise.set("another value")); + assertThat(e).hasMessageThat().isEqualTo("Already set!"); + } + + @Test + public void runWhenSet_beforeFulfill_listenerIsNotifiedUponSet() { + promise.runWhenSet(mockListener1); + + // Should not have been called yet. + verify(mockListener1, never()).notify(FULFILLED_VALUE); + + promise.set(FULFILLED_VALUE); + + // Now it should be called. + verify(mockListener1, times(1)).notify(FULFILLED_VALUE); + } + + @Test + public void runWhenSet_afterSet_listenerIsNotifiedImmediately() { + promise.set(FULFILLED_VALUE); + promise.runWhenSet(mockListener1); + + // Should have been called immediately. + verify(mockListener1, times(1)).notify(FULFILLED_VALUE); + } + + @Test + public void multipleListeners_addedBeforeSet_allNotifiedInOrder() { + promise.runWhenSet(mockListener1); + promise.runWhenSet(mockListener2); + + promise.set(FULFILLED_VALUE); + + InOrder inOrder = inOrder(mockListener1, mockListener2); + inOrder.verify(mockListener1).notify(FULFILLED_VALUE); + inOrder.verify(mockListener2).notify(FULFILLED_VALUE); + } + + @Test + public void listenerThrows_duringSet_propagatesException() { + // A listener that will throw when notified. 
+ Listener throwingListener = + (value) -> { + throw new UnsupportedOperationException("Listener failed"); + }; + + promise.runWhenSet(throwingListener); + + // Fulfilling the promise should now throw the exception from the listener. + UnsupportedOperationException e = + assertThrows(UnsupportedOperationException.class, () -> promise.set(FULFILLED_VALUE)); + assertThat(e).hasMessageThat().isEqualTo("Listener failed"); + } + + @Test + public void listenerThrows_whenAddedAfterSet_propagatesException() { + promise.set(FULFILLED_VALUE); + + // A listener that will throw when notified. + Listener throwingListener = + (value) -> { + throw new UnsupportedOperationException("Listener failed"); + }; + + // Running the listener should throw immediately because the promise is already fulfilled. + UnsupportedOperationException e = + assertThrows( + UnsupportedOperationException.class, () -> promise.runWhenSet(throwingListener)); + assertThat(e).hasMessageThat().isEqualTo("Listener failed"); + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/TransactionUtilsTest.java b/binder/src/test/java/io/grpc/binder/internal/TransactionUtilsTest.java new file mode 100644 index 00000000000..44a3ce3ef26 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/TransactionUtilsTest.java @@ -0,0 +1,70 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.binder.internal; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.binder.internal.TransactionUtils.newCallerFilteringHandler; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import android.os.Binder; +import android.os.Parcel; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.robolectric.RobolectricTestRunner; +import org.robolectric.shadows.ShadowBinder; + +@RunWith(RobolectricTestRunner.class) +public final class TransactionUtilsTest { + + @Rule public MockitoRule mocks = MockitoJUnit.rule(); + + @Mock LeakSafeOneWayBinder.TransactionHandler mockHandler; + + @Test + public void shouldIgnoreTransactionFromWrongUid() { + Parcel p = Parcel.obtain(); + int originalUid = Binder.getCallingUid(); + try { + when(mockHandler.handleTransaction(eq(1234), same(p))).thenReturn(true); + LeakSafeOneWayBinder.TransactionHandler uid100OnlyHandler = + newCallerFilteringHandler(1000, mockHandler); + + ShadowBinder.setCallingUid(9999); + boolean result = uid100OnlyHandler.handleTransaction(1234, p); + assertThat(result).isFalse(); + verify(mockHandler, never()).handleTransaction(anyInt(), any()); + + ShadowBinder.setCallingUid(1000); + result = uid100OnlyHandler.handleTransaction(1234, p); + assertThat(result).isTrue(); + verify(mockHandler).handleTransaction(1234, p); + } finally { + ShadowBinder.setCallingUid(originalUid); + p.recycle(); + } + } +} diff --git a/binder/src/testFixtures/java/io/grpc/binder/internal/BinderClientTransportBuilder.java 
b/binder/src/testFixtures/java/io/grpc/binder/internal/BinderClientTransportBuilder.java new file mode 100644 index 00000000000..f732ff64663 --- /dev/null +++ b/binder/src/testFixtures/java/io/grpc/binder/internal/BinderClientTransportBuilder.java @@ -0,0 +1,61 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.grpc.ChannelLogger; +import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; +import io.grpc.internal.TestUtils.NoopChannelLogger; +import java.net.SocketAddress; + +/** + * Helps unit tests create {@link BinderClientTransport} instances without having to mention + * irrelevant details (go/tott/719). 
+ */ +public class BinderClientTransportBuilder { + private BinderClientTransportFactory factory; + private SocketAddress serverAddress; + private ChannelLogger channelLogger = new NoopChannelLogger(); + private io.grpc.internal.ClientTransportFactory.ClientTransportOptions options = + new ClientTransportOptions(); + + public BinderClientTransportBuilder setServerAddress(SocketAddress serverAddress) { + this.serverAddress = checkNotNull(serverAddress); + return this; + } + + public BinderClientTransportBuilder setChannelLogger(ChannelLogger channelLogger) { + this.channelLogger = checkNotNull(channelLogger); + return this; + } + + public BinderClientTransportBuilder setOptions(ClientTransportOptions options) { + this.options = checkNotNull(options); + return this; + } + + public BinderClientTransportBuilder setFactory(BinderClientTransportFactory factory) { + this.factory = checkNotNull(factory); + return this; + } + + public BinderClientTransport build() { + return factory.newClientTransport( + checkNotNull(serverAddress), checkNotNull(options), checkNotNull(channelLogger)); + } +} diff --git a/binder/src/testFixtures/java/io/grpc/binder/internal/FakeDeadBinder.java b/binder/src/testFixtures/java/io/grpc/binder/internal/FakeDeadBinder.java new file mode 100644 index 00000000000..5bce7498c4b --- /dev/null +++ b/binder/src/testFixtures/java/io/grpc/binder/internal/FakeDeadBinder.java @@ -0,0 +1,74 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import android.os.DeadObjectException; +import android.os.IBinder; +import android.os.IInterface; +import android.os.Parcel; +import android.os.RemoteException; +import java.io.FileDescriptor; + +/** An {@link IBinder} that behaves as if its hosting process has died, for testing. */ +public class FakeDeadBinder implements IBinder { + @Override + public boolean isBinderAlive() { + return false; + } + + @Override + public IInterface queryLocalInterface(String descriptor) { + return null; + } + + @Override + public String getInterfaceDescriptor() throws RemoteException { + throw new DeadObjectException(); + } + + @Override + public boolean pingBinder() { + return false; + } + + @Override + public void dump(FileDescriptor fd, String[] args) throws RemoteException { + throw new DeadObjectException(); + } + + @Override + public void dumpAsync(FileDescriptor fd, String[] args) throws RemoteException { + throw new DeadObjectException(); + } + + @Override + public boolean transact(int code, Parcel data, Parcel reply, int flags) throws RemoteException { + throw new DeadObjectException(); + } + + @Override + public void linkToDeath(DeathRecipient r, int flags) throws RemoteException { + throw new DeadObjectException(); + } + + @Override + public boolean unlinkToDeath(DeathRecipient deathRecipient, int flags) { + // No need to check whether 'deathRecipient' was ever actually passed to linkToDeath(): Per our + // API contract, if "the IBinder has already died" we never throw and always return false. 
+ return false; + } +} diff --git a/binder/src/testFixtures/java/io/grpc/binder/internal/SettableAsyncSecurityPolicy.java b/binder/src/testFixtures/java/io/grpc/binder/internal/SettableAsyncSecurityPolicy.java new file mode 100644 index 00000000000..2cb22c2fdbf --- /dev/null +++ b/binder/src/testFixtures/java/io/grpc/binder/internal/SettableAsyncSecurityPolicy.java @@ -0,0 +1,83 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Status; +import io.grpc.binder.AsyncSecurityPolicy; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * An {@link AsyncSecurityPolicy} that lets unit tests verify the exact order of authorization + * requests and respond to them one at a time. 
+ */ +public class SettableAsyncSecurityPolicy extends AsyncSecurityPolicy { + private final LinkedBlockingDeque pendingRequests = new LinkedBlockingDeque<>(); + + @Override + public ListenableFuture checkAuthorizationAsync(int uid) { + AuthRequest request = new AuthRequest(uid); + pendingRequests.add(request); + return request.resultFuture; + } + + /** + * Waits for the next "check authorization" request to be made and returns it, throwing in case no + * request arrives in time. + */ + public AuthRequest takeNextAuthRequest(long timeout, TimeUnit unit) + throws InterruptedException, TimeoutException { + AuthRequest nextAuthRequest = pendingRequests.poll(timeout, unit); + if (nextAuthRequest == null) { + throw new TimeoutException(); + } + return nextAuthRequest; + } + + /** Represents a single call to {@link AsyncSecurityPolicy#checkAuthorizationAsync(int)}. */ + public static class AuthRequest { + + /** The argument passed to {@link AsyncSecurityPolicy#checkAuthorizationAsync(int)}. */ + public final int uid; + + private final SettableFuture resultFuture = SettableFuture.create(); + + private AuthRequest(int uid) { + this.uid = uid; + } + + /** Provides this SecurityPolicy's response to this authorization request. */ + public void setResult(Status result) { + checkState(resultFuture.set(result)); + } + + /** Simulates an exceptional response to this authorization request. */ + public void setResult(Throwable t) { + checkState(resultFuture.setException(t)); + } + + /** Tests if the future returned for this authorization request was cancelled by the caller. 
*/ + public boolean isCancelled() { + return resultFuture.isCancelled(); + } + } +} \ No newline at end of file diff --git a/bom/build.gradle b/bom/build.gradle index 1b1f98cff18..f7f3918372f 100644 --- a/bom/build.gradle +++ b/bom/build.gradle @@ -1,40 +1,32 @@ plugins { + id 'java-platform' id "maven-publish" } description = 'gRPC: BOM' +gradle.projectsEvaluated { + def projectsToInclude = rootProject.subprojects.findAll { + return it.name != 'grpc-compiler' + && it.plugins.hasPlugin('java') + && it.plugins.hasPlugin('maven-publish') + && it.tasks.findByName('publishMavenPublicationToMavenRepository')?.enabled + } + dependencies { + constraints { + projectsToInclude.each { api it } + } + } +} + publishing { publications { maven(MavenPublication) { - // remove all other artifacts since BOM doesn't generates any Jar - artifacts = [] - + from components.javaPlatform pom.withXml { - // Generate bom using subprojects - def internalProjects = [ - project.name, - 'grpc-compiler', - ] - - def dependencyManagement = asNode().appendNode('dependencyManagement') - def dependencies = dependencyManagement.appendNode('dependencies') - rootProject.subprojects.each { subproject -> - if (internalProjects.contains(subproject.name)) { - return - } - if (!subproject.hasProperty('publishMavenPublicationToMavenRepository')) { - return - } - if (!subproject.publishMavenPublicationToMavenRepository.enabled) { - return - } - def dependencyNode = dependencies.appendNode('dependency') - dependencyNode.appendNode('groupId', subproject.group) - dependencyNode.appendNode('artifactId', subproject.name) - dependencyNode.appendNode('version', subproject.version) - } + def dependencies = asNode().dependencyManagement.dependencies.last() // add protoc gen (produced by grpc-compiler with different artifact name) + // not sure how to express "pom" in gradle, kept in XML def dependencyNode = dependencies.appendNode('dependency') dependencyNode.appendNode('groupId', project.group) 
dependencyNode.appendNode('artifactId', 'protoc-gen-grpc-java') diff --git a/build.gradle b/build.gradle index 740f534e136..2cf3439ea76 100644 --- a/build.gradle +++ b/build.gradle @@ -21,13 +21,28 @@ subprojects { apply plugin: "net.ltgt.errorprone" group = "io.grpc" - version = "1.68.0-SNAPSHOT" // CURRENT_GRPC_VERSION + version = "1.81.0-SNAPSHOT" // CURRENT_GRPC_VERSION repositories { maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } - mavenCentral() - mavenLocal() + url = "https://maven-central.storage-download.googleapis.com/maven2/" + metadataSources { + mavenPom() + ignoreGradleMetadataRedirection() + } + } + mavenCentral() { + metadataSources { + mavenPom() + ignoreGradleMetadataRedirection() + } + } + mavenLocal() { + metadataSources { + mavenPom() + ignoreGradleMetadataRedirection() + } + } } tasks.withType(JavaCompile).configureEach { @@ -136,7 +151,7 @@ subprojects { appendToProperty( it.options.errorprone.excludedPaths, ".*/src/generated/[^/]+/java/.*" + - "|.*/build/generated/source/proto/[^/]+/java/.*", + "|.*/build/generated/sources/proto/[^/]+/java/.*", "|") } } @@ -182,6 +197,25 @@ subprojects { } } + plugins.withId("com.android.base") { + android { + lint { + abortOnError true + if (rootProject.hasProperty('failOnWarnings') && rootProject.failOnWarnings.toBoolean()) { + warningsAsErrors true + } + } + } + tasks.withType(JavaCompile).configureEach { + it.options.compilerArgs += [ + "-Xlint:all" + ] + if (rootProject.hasProperty('failOnWarnings') && rootProject.failOnWarnings.toBoolean()) { + it.options.compilerArgs += ["-Werror"] + } + } + } + plugins.withId("java") { dependencies { testImplementation libraries.junit, @@ -223,12 +257,12 @@ subprojects { // At a test failure, log the stack trace to the console so that we don't // have to open the HTML in a browser. 
- tasks.named("test").configure { + tasks.withType(Test).configureEach { testLogging { exceptionFormat = 'full' - showExceptions true - showCauses true - showStackTraces true + showExceptions = true + showCauses = true + showStackTraces = true } maxHeapSize = '1500m' } @@ -311,7 +345,7 @@ subprojects { } } - plugins.withId("com.github.johnrengelman.shadow") { + plugins.withId("com.gradleup.shadow") { tasks.named("shadowJar").configure { // Do a dance to remove Class-Path. This needs to run after the doFirst() from the // shadow plugin that adds Class-Path and before the core jar action. Using doFirst will @@ -374,11 +408,11 @@ subprojects { url = new File(rootProject.repositoryDir).toURI() } else { String stagingUrl + String baseUrl = "https://ossrh-staging-api.central.sonatype.com/service/local" if (rootProject.hasProperty('repositoryId')) { - stagingUrl = 'https://oss.sonatype.org/service/local/staging/deployByRepositoryId/' + - rootProject.repositoryId + stagingUrl = "${baseUrl}/staging/deployByRepositoryId/" + rootProject.repositoryId } else { - stagingUrl = 'https://oss.sonatype.org/service/local/staging/deploy/maven2/' + stagingUrl = "${baseUrl}/staging/deploy/maven2/" } credentials { if (rootProject.hasProperty('ossrhUsername') && rootProject.hasProperty('ossrhPassword')) { @@ -387,7 +421,7 @@ subprojects { } } def releaseUrl = stagingUrl - def snapshotUrl = 'https://oss.sonatype.org/content/repositories/snapshots/' + def snapshotUrl = 'https://central.sonatype.com/repository/maven-snapshots/' url = version.endsWith('SNAPSHOT') ? 
snapshotUrl : releaseUrl } } @@ -395,7 +429,7 @@ subprojects { } signing { - required false + required = false sign publishing.publications.maven } @@ -499,4 +533,5 @@ configurations { } } -tasks.register('checkForUpdates', CheckForUpdatesTask, project.configurations.checkForUpdates, "libs") +tasks.register('checkForUpdates', CheckForUpdatesTask, project.configurations.checkForUpdates, "libs", layout.projectDirectory.file("gradle/libs.versions.toml")) + diff --git a/buildSrc/src/main/java/io/grpc/gradle/CheckForUpdatesTask.java b/buildSrc/src/main/java/io/grpc/gradle/CheckForUpdatesTask.java index 9d0156a1b72..b7c28dbbb2d 100644 --- a/buildSrc/src/main/java/io/grpc/gradle/CheckForUpdatesTask.java +++ b/buildSrc/src/main/java/io/grpc/gradle/CheckForUpdatesTask.java @@ -16,11 +16,15 @@ package io.grpc.gradle; +import java.io.IOException; +import java.nio.file.Files; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import javax.inject.Inject; import org.gradle.api.DefaultTask; import org.gradle.api.artifacts.Configuration; @@ -32,6 +36,7 @@ import org.gradle.api.artifacts.result.ResolvedComponentResult; import org.gradle.api.artifacts.result.ResolvedDependencyResult; import org.gradle.api.artifacts.result.UnresolvedDependencyResult; +import org.gradle.api.file.RegularFile; import org.gradle.api.provider.Provider; import org.gradle.api.tasks.Input; import org.gradle.api.tasks.Nested; @@ -45,7 +50,23 @@ public abstract class CheckForUpdatesTask extends DefaultTask { private final Set libraries; @Inject - public CheckForUpdatesTask(Configuration updateConf, String catalog) { + public CheckForUpdatesTask(Configuration updateConf, String catalog, RegularFile commentFile) + throws IOException { + // Check for overrides to the default version selection ('+'), using comments of the form: + // # checkForUpdates: library-name:1.2.+ + 
List fileComments = Files.lines(commentFile.getAsFile().toPath()) + .filter(l -> l.matches("# *checkForUpdates:.*")) + .map(l -> l.replaceFirst("# *checkForUpdates:", "").trim()) + .collect(Collectors.toList()); + Map aliasToVersionSelector = new HashMap<>(2*fileComments.size()); + for (String comment : fileComments) { + String[] parts = comment.split(":", 2); + String name = parts[0].replaceAll("[_-]", "."); + if (aliasToVersionSelector.put(name, parts[1]) != null) { + throw new RuntimeException("Duplicate checkForUpdates comment for library: " + name); + } + } + updateConf.setVisible(false); updateConf.setTransitive(false); VersionCatalog versionCatalog = getProject().getExtensions().getByType(VersionCatalogsExtension.class).named(catalog); @@ -59,8 +80,12 @@ public CheckForUpdatesTask(Configuration updateConf, String catalog) { oldConf.getDependencies().add(oldDep); Configuration newConf = updateConf.copy(); + String versionSelector = aliasToVersionSelector.remove(name); + if (versionSelector == null) { + versionSelector = "+"; + } Dependency newDep = getProject().getDependencies().create( - depMap(dep.getGroup(), dep.getName(), "+", "pom")); + depMap(dep.getGroup(), dep.getName(), versionSelector, "pom")); newConf.getDependencies().add(newDep); libraries.add(new Library( @@ -68,6 +93,10 @@ public CheckForUpdatesTask(Configuration updateConf, String catalog) { oldConf.getIncoming().getResolutionResult().getRootComponent(), newConf.getIncoming().getResolutionResult().getRootComponent())); } + if (!aliasToVersionSelector.isEmpty()) { + throw new RuntimeException( + "Unused checkForUpdates comments: " + aliasToVersionSelector.keySet()); + } this.libraries = Collections.unmodifiableSet(libraries); } @@ -96,10 +125,16 @@ public void checkForUpdates() { "- Current version of libs.%s not resolved", name)); continue; } + DependencyResult newResult = lib.getNewResult().get().getDependencies().iterator().next(); + if (newResult instanceof UnresolvedDependencyResult) { + 
System.out.println(String.format( + "- New version of libs.%s not resolved", name)); + continue; + } ModuleVersionIdentifier oldId = ((ResolvedDependencyResult) oldResult).getSelected().getModuleVersion(); - ModuleVersionIdentifier newId = ((ResolvedDependencyResult) lib.getNewResult().get() - .getDependencies().iterator().next()).getSelected().getModuleVersion(); + ModuleVersionIdentifier newId = + ((ResolvedDependencyResult) newResult).getSelected().getModuleVersion(); if (oldId != newId) { System.out.println(String.format( "libs.%s = %s %s -> %s", diff --git a/buildscripts/checkstyle.xml b/buildscripts/checkstyle.xml index 035b1dfc900..0ec8ecc79ce 100644 --- a/buildscripts/checkstyle.xml +++ b/buildscripts/checkstyle.xml @@ -38,6 +38,12 @@ + + + + + + diff --git a/buildscripts/cloudbuild-testing.yaml b/buildscripts/cloudbuild-testing.yaml new file mode 100644 index 00000000000..623b85b6882 --- /dev/null +++ b/buildscripts/cloudbuild-testing.yaml @@ -0,0 +1,64 @@ +substitutions: + _GAE_SERVICE_ACCOUNT: appengine-testing-java@grpc-testing.iam.gserviceaccount.com +options: + env: + - BUILD_ID=$BUILD_ID + - KOKORO_GAE_SERVICE=java-gae-interop-test + - DUMMY_DEFAULT_VERSION=dummy-default + - GRADLE_OPTS=-Dorg.gradle.jvmargs='-Xmx1g' + - GRADLE_FLAGS=-PskipCodegen=true -PskipAndroid=true + logging: CLOUD_LOGGING_ONLY + machineType: E2_HIGHCPU_8 + +steps: +- id: clean-stale-deploys + name: gcr.io/cloud-builders/gcloud + allowFailure: true + script: | + #!/usr/bin/env bash + set -e + echo "Cleaning out stale deploys from previous runs, it is ok if this part fails" + # If the test fails, the deployment is leaked. + # Delete all versions whose name is not 'dummy-default' and is older than 1 hour. 
+ # This expression is an ISO8601 relative date: + # https://cloud.google.com/sdk/gcloud/reference/topic/datetimes + (gcloud app versions list --format="get(version.id)" \ + --filter="service=$KOKORO_GAE_SERVICE AND NOT version : '$DUMMY_DEFAULT_VERSION' AND version.createTime<'-p1h'" \ + | xargs -i gcloud app services delete "$KOKORO_GAE_SERVICE" --version {} --quiet) || true + +- name: gcr.io/cloud-builders/docker + args: ['build', '-t', 'gae-build', 'buildscripts/gae-build/'] + +- id: build + name: gae-build + script: | + #!/usr/bin/env bash + exec ./gradlew $GRADLE_FLAGS :grpc-gae-interop-testing-jdk8:appengineStage + +- id: deploy + name: gcr.io/cloud-builders/gcloud + args: + - app + - deploy + - gae-interop-testing/gae-jdk8/build/staged-app/app.yaml + - --service-account=$_GAE_SERVICE_ACCOUNT + - --no-promote + - --no-stop-previous-version + - --version=cb-$BUILD_ID + +- id: runInteropTestRemote + name: eclipse-temurin:17-jdk + env: + - PROJECT_ID=$PROJECT_ID + script: | + #!/usr/bin/env bash + exec ./gradlew $GRADLE_FLAGS --stacktrace -PgaeDeployVersion="cb-$BUILD_ID" \ + -PgaeProjectId="$PROJECT_ID" :grpc-gae-interop-testing-jdk8:runInteropTestRemote + +- id: cleanup + name: gcr.io/cloud-builders/gcloud + script: | + #!/usr/bin/env bash + set -e + echo "Performing cleanup now." + gcloud app services delete "$KOKORO_GAE_SERVICE" --version "cb-$BUILD_ID" --quiet diff --git a/buildscripts/gae-build/Dockerfile b/buildscripts/gae-build/Dockerfile new file mode 100644 index 00000000000..7e68b270801 --- /dev/null +++ b/buildscripts/gae-build/Dockerfile @@ -0,0 +1,10 @@ +FROM eclipse-temurin:17-jdk + +# The AppEngine Gradle plugin downloads and runs its own gcloud to get the .jar +# to link against, so we need Python even if we use gcloud deploy directly +# instead of using the plugin. 
+RUN export DEBIAN_FRONTEND=noninteractive && \ + apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y --no-install-recommends python3 && \ + rm -rf /var/lib/apt/lists/* diff --git a/buildscripts/grpc-java-artifacts/Dockerfile b/buildscripts/grpc-java-artifacts/Dockerfile index 736babe9d8e..54c595cd960 100644 --- a/buildscripts/grpc-java-artifacts/Dockerfile +++ b/buildscripts/grpc-java-artifacts/Dockerfile @@ -27,7 +27,11 @@ RUN mkdir -p "$ANDROID_HOME/cmdline-tools" && \ mv "$ANDROID_HOME/cmdline-tools/cmdline-tools" "$ANDROID_HOME/cmdline-tools/latest" && \ yes | "$ANDROID_HOME/cmdline-tools/latest/bin/sdkmanager" --licenses +RUN curl -Ls https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-linux-x86_64.tar.gz | \ + tar xz -C /var/local + # Install Maven -RUN curl -Ls https://dlcdn.apache.org/maven/maven-3/3.8.8/binaries/apache-maven-3.8.8-bin.tar.gz | \ +RUN curl -Ls https://archive.apache.org/dist/maven/maven-3/3.8.8/binaries/apache-maven-3.8.8-bin.tar.gz | \ tar xz -C /var/local -ENV PATH /var/local/apache-maven-3.8.8/bin:$PATH +ENV PATH /var/local/cmake-3.26.3-linux-x86_64/bin:/var/local/apache-maven-3.8.8/bin:$PATH + diff --git a/buildscripts/grpc-java-artifacts/Dockerfile.multiarch.base b/buildscripts/grpc-java-artifacts/Dockerfile.multiarch.base index 8f7cfae2f52..da2c46904ca 100644 --- a/buildscripts/grpc-java-artifacts/Dockerfile.multiarch.base +++ b/buildscripts/grpc-java-artifacts/Dockerfile.multiarch.base @@ -10,5 +10,11 @@ RUN export DEBIAN_FRONTEND=noninteractive && \ g++-aarch64-linux-gnu \ g++-powerpc64le-linux-gnu \ openjdk-8-jdk \ + pkg-config \ && \ rm -rf /var/lib/apt/lists/* + +RUN curl -Ls https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-linux-x86_64.tar.gz | \ + tar xz -C /var/local +ENV PATH /var/local/cmake-3.26.3-linux-x86_64/bin:$PATH + diff --git a/buildscripts/grpc-java-artifacts/Dockerfile.ubuntu2004.base b/buildscripts/grpc-java-artifacts/Dockerfile.ubuntu2004.base index 
2d11d76c373..e987fb3e684 100644 --- a/buildscripts/grpc-java-artifacts/Dockerfile.ubuntu2004.base +++ b/buildscripts/grpc-java-artifacts/Dockerfile.ubuntu2004.base @@ -9,5 +9,11 @@ RUN export DEBIAN_FRONTEND=noninteractive && \ curl \ g++-s390x-linux-gnu \ openjdk-8-jdk \ + pkg-config \ && \ rm -rf /var/lib/apt/lists/* + +RUN curl -Ls https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-linux-x86_64.tar.gz | \ + tar xz -C /var/local +ENV PATH /var/local/cmake-3.26.3-linux-x86_64/bin:$PATH + diff --git a/buildscripts/kokoro/android-interop.sh b/buildscripts/kokoro/android-interop.sh index 43bad26f1ec..877311daca5 100755 --- a/buildscripts/kokoro/android-interop.sh +++ b/buildscripts/kokoro/android-interop.sh @@ -2,17 +2,8 @@ set -exu -o pipefail -# Install gRPC and codegen for the Android interop app -# (a composite gradle build can't find protoc-gen-grpc-java) - cd github/grpc-java -export GRADLE_OPTS=-Xmx512m -export LDFLAGS=-L/tmp/protobuf/lib -export CXXFLAGS=-I/tmp/protobuf/include -export LD_LIBRARY_PATH=/tmp/protobuf/lib -export OS_NAME=$(uname) - export ANDROID_HOME=/tmp/Android/Sdk mkdir -p "${ANDROID_HOME}/cmdline-tools" curl -Ls -o cmdline.zip \ @@ -22,15 +13,12 @@ rm cmdline.zip mv "${ANDROID_HOME}/cmdline-tools/cmdline-tools" "${ANDROID_HOME}/cmdline-tools/latest" (yes || true) | "${ANDROID_HOME}/cmdline-tools/latest/bin/sdkmanager" --licenses -# Proto deps -buildscripts/make_dependencies.sh - # Build Android with Java 11, this adds it to the PATH sudo update-java-alternatives --set java-1.11.0-openjdk-amd64 # Unset any existing JAVA_HOME env var to stop Gradle from using it unset JAVA_HOME -GRADLE_FLAGS="-Pandroid.useAndroidX=true" +GRADLE_FLAGS="-Pandroid.useAndroidX=true -Dorg.gradle.jvmargs=-Xmx1024m -PskipCodegen=true" ./gradlew $GRADLE_FLAGS :grpc-android-interop-testing:assembleDebug ./gradlew $GRADLE_FLAGS :grpc-android-interop-testing:assembleDebugAndroidTest diff --git a/buildscripts/kokoro/android.sh 
b/buildscripts/kokoro/android.sh index 13983e747b7..677825ae66b 100755 --- a/buildscripts/kokoro/android.sh +++ b/buildscripts/kokoro/android.sh @@ -9,9 +9,6 @@ BASE_DIR="$(pwd)" cd "$BASE_DIR/github/grpc-java" -export LDFLAGS=-L/tmp/protobuf/lib -export CXXFLAGS=-I/tmp/protobuf/include -export LD_LIBRARY_PATH=/tmp/protobuf/lib export OS_NAME=$(uname) cat <> gradle.properties @@ -30,10 +27,18 @@ unzip -qd "${ANDROID_HOME}/cmdline-tools" cmdline.zip rm cmdline.zip mv "${ANDROID_HOME}/cmdline-tools/cmdline-tools" "${ANDROID_HOME}/cmdline-tools/latest" (yes || true) | "${ANDROID_HOME}/cmdline-tools/latest/bin/sdkmanager" --licenses - +curl -Ls https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-linux-x86_64.tar.gz | \ + tar xz -C /tmp +export PATH=/tmp/cmake-3.26.3-linux-x86_64/bin:$PATH + # Proto deps buildscripts/make_dependencies.sh +sudo apt-get update && sudo apt-get install pkg-config +export LDFLAGS="$(PKG_CONFIG_PATH=/tmp/protobuf/lib/pkgconfig pkg-config --libs protobuf)" +export CXXFLAGS="$(PKG_CONFIG_PATH=/tmp/protobuf/lib/pkgconfig pkg-config --cflags protobuf)" +export LD_LIBRARY_PATH=/tmp/protobuf/lib + # Build Android with Java 11, this adds it to the PATH sudo update-java-alternatives --set java-1.11.0-openjdk-amd64 # Unset any existing JAVA_HOME env var to stop Gradle from using it @@ -98,6 +103,7 @@ cd $BASE_DIR/github/grpc-java ./gradlew clean git checkout HEAD^ ./gradlew --stop # use a new daemon to build the previous commit +GRADLE_FLAGS="${GRADLE_FLAGS} -PskipCodegen=true" # skip codegen for build from previous commit since it wasn't built with --std=c++14 when making this change ./gradlew publishToMavenLocal $GRADLE_FLAGS cd examples/android/helloworld/ ../../gradlew build $GRADLE_FLAGS @@ -126,15 +132,18 @@ fi # Update the statuses with the deltas +set +x gsutil cp gs://grpc-testing-secrets/github_credentials/oauth_token.txt ~/ desc="New DEX reference count: $(printf "%'d" "$new_dex_count") (delta: $(printf "%'d" 
"$dex_count_delta"))" +echo "Setting status: $desc" curl -f -s -X POST -H "Content-Type: application/json" \ -H "Authorization: token $(cat ~/oauth_token.txt | tr -d '\n')" \ -d '{"state": "success", "context": "android/dex_diff", "description": "'"${desc}"'"}' \ "https://api.github.com/repos/grpc/grpc-java/statuses/${KOKORO_GITHUB_PULL_REQUEST_COMMIT}" desc="New APK size in bytes: $(printf "%'d" "$new_apk_size") (delta: $(printf "%'d" "$apk_size_delta"))" +echo "Setting status: $desc" curl -f -s -X POST -H "Content-Type: application/json" \ -H "Authorization: token $(cat ~/oauth_token.txt | tr -d '\n')" \ -d '{"state": "success", "context": "android/apk_diff", "description": "'"${desc}"'"}' \ diff --git a/buildscripts/kokoro/gae-interop.sh b/buildscripts/kokoro/gae-interop.sh deleted file mode 100755 index c4ce56cac52..00000000000 --- a/buildscripts/kokoro/gae-interop.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash - -set -exu -o pipefail -if [[ -f /VERSION ]]; then - cat /VERSION -fi - -KOKORO_GAE_SERVICE="java-gae-interop-test" - -# We deploy as different versions of a single service, this way any stale -# lingering deploys can be easily cleaned up by purging all running versions -# of this service. -KOKORO_GAE_APP_VERSION=$(hostname) - -# A dummy version that can be the recipient of all traffic, so that the kokoro test version can be -# set to 0 traffic. This is a requirement in order to delete it. -DUMMY_DEFAULT_VERSION='dummy-default' - -function cleanup() { - echo "Performing cleanup now." - gcloud app services delete $KOKORO_GAE_SERVICE --version $KOKORO_GAE_APP_VERSION --quiet -} -trap cleanup SIGHUP SIGINT SIGTERM EXIT - -readonly GRPC_JAVA_DIR="$(cd "$(dirname "$0")"/../.. && pwd)" -cd "$GRPC_JAVA_DIR" - -## -## Deploy the dummy 'default' version of the service -## -GRADLE_FLAGS="--stacktrace -DgaeStopPreviousVersion=false -PskipCodegen=true -PskipAndroid=true" -export GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" - -# Deploy the dummy 'default' version. 
We only require that it exists when cleanup() is called. -# It ok if we race with another run and fail here, because the end result is idempotent. -set +e -if ! gcloud app versions describe "$DUMMY_DEFAULT_VERSION" --service="$KOKORO_GAE_SERVICE"; then - ./gradlew $GRADLE_FLAGS -DgaeDeployVersion="$DUMMY_DEFAULT_VERSION" -DgaePromote=true :grpc-gae-interop-testing-jdk8:appengineDeploy -else - echo "default version already exists: $DUMMY_DEFAULT_VERSION" -fi -set -e - -# Deploy and test the real app (jdk8) -./gradlew $GRADLE_FLAGS -DgaeDeployVersion="$KOKORO_GAE_APP_VERSION" :grpc-gae-interop-testing-jdk8:runInteropTestRemote - -set +e -echo "Cleaning out stale deploys from previous runs, it is ok if this part fails" - -# Sometimes the trap based cleanup fails. -# Delete all versions whose name is not 'dummy-default' and is older than 1 hour. -# This expression is an ISO8601 relative date: -# https://cloud.google.com/sdk/gcloud/reference/topic/datetimes -gcloud app versions list --format="get(version.id)" --filter="service=$KOKORO_GAE_SERVICE AND NOT version : 'dummy-default' AND version.createTime<'-p1h'" | xargs -i gcloud app services delete "$KOKORO_GAE_SERVICE" --version {} --quiet -exit 0 diff --git a/buildscripts/kokoro/linux_aarch64.cfg b/buildscripts/kokoro/linux_aarch64.cfg deleted file mode 100644 index 325d910c5ea..00000000000 --- a/buildscripts/kokoro/linux_aarch64.cfg +++ /dev/null @@ -1,13 +0,0 @@ -# Config file for internal CI - -# Location of the continuous shell script in repository. 
-build_file: "grpc-java/buildscripts/kokoro/linux_aarch64.sh" -timeout_mins: 60 - -action { - define_artifacts { - regex: "github/grpc-java/**/build/test-results/**/sponge_log.xml" - regex: "github/grpc-java/mvn-artifacts/**" - regex: "github/grpc-java/artifacts/**" - } -} diff --git a/buildscripts/kokoro/linux_aarch64.sh b/buildscripts/kokoro/linux_aarch64.sh deleted file mode 100755 index f4a1292efb5..00000000000 --- a/buildscripts/kokoro/linux_aarch64.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash -set -veux -o pipefail - -if [[ -f /VERSION ]]; then - cat /VERSION -fi - -readonly GRPC_JAVA_DIR="$(cd "$(dirname "$0")"/../.. && pwd)" - -. "$GRPC_JAVA_DIR"/buildscripts/kokoro/kokoro.sh -trap spongify_logs EXIT - -cd github/grpc-java - -buildscripts/qemu_helpers/prepare_qemu.sh - -buildscripts/run_arm64_tests_in_docker.sh diff --git a/buildscripts/kokoro/macos.cfg b/buildscripts/kokoro/macos.cfg index a58691a7102..4c79743692e 100644 --- a/buildscripts/kokoro/macos.cfg +++ b/buildscripts/kokoro/macos.cfg @@ -2,7 +2,7 @@ # Location of the continuous shell script in repository. build_file: "grpc-java/buildscripts/kokoro/macos.sh" -timeout_mins: 45 +timeout_mins: 60 # We always build mvn artifacts. action { diff --git a/buildscripts/kokoro/macos.sh b/buildscripts/kokoro/macos.sh index 97259231ee8..0240c0650f7 100755 --- a/buildscripts/kokoro/macos.sh +++ b/buildscripts/kokoro/macos.sh @@ -1,5 +1,6 @@ #!/bin/bash set -veux -o pipefail +CMAKE_VERSION=3.31.10 if [[ -f /VERSION ]]; then cat /VERSION @@ -7,6 +8,10 @@ fi readonly GRPC_JAVA_DIR="$(cd "$(dirname "$0")"/../.. && pwd)" +DOWNLOAD_DIR=/tmp/source +mkdir -p ${DOWNLOAD_DIR} +curl -Ls https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-macos-universal.tar.gz | tar xz -C ${DOWNLOAD_DIR} + # We had problems with random tests timing out because it took seconds to do # trivial (ns) operations. 
The Kokoro Mac machines have 2 cores with 4 logical # threads, so Gradle should be using 4 workers by default. @@ -15,4 +20,9 @@ export GRADLE_FLAGS="${GRADLE_FLAGS:-} --max-workers=2" . "$GRPC_JAVA_DIR"/buildscripts/kokoro/kokoro.sh trap spongify_logs EXIT +brew install --cask temurin@8 +export PATH="$(/usr/libexec/java_home -v"1.8.0")/bin:${DOWNLOAD_DIR}/cmake-${CMAKE_VERSION}-macos-universal/CMake.app/Contents/bin:${PATH}" +export JAVA_HOME="$(/usr/libexec/java_home -v"1.8.0")" +brew install maven + "$GRPC_JAVA_DIR"/buildscripts/kokoro/unix.sh diff --git a/buildscripts/kokoro/psm-cloud-run.cfg b/buildscripts/kokoro/psm-cloud-run.cfg new file mode 100644 index 00000000000..1f2d6da208f --- /dev/null +++ b/buildscripts/kokoro/psm-cloud-run.cfg @@ -0,0 +1,17 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. +build_file: "grpc-java/buildscripts/kokoro/psm-interop-test-java.sh" +timeout_mins: 240 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*.log" + strip_prefix: "artifacts" + } +} +env_vars { + key: "PSM_TEST_SUITE" + value: "cloud_run" +} diff --git a/buildscripts/kokoro/psm-dualstack.cfg b/buildscripts/kokoro/psm-dualstack.cfg index 55c906bc4ec..a55d91a95b0 100644 --- a/buildscripts/kokoro/psm-dualstack.cfg +++ b/buildscripts/kokoro/psm-dualstack.cfg @@ -2,7 +2,7 @@ # Location of the continuous shell script in repository. build_file: "grpc-java/buildscripts/kokoro/psm-interop-test-java.sh" -timeout_mins: 120 +timeout_mins: 240 action { define_artifacts { diff --git a/buildscripts/kokoro/psm-light.cfg b/buildscripts/kokoro/psm-light.cfg new file mode 100644 index 00000000000..decd179efa3 --- /dev/null +++ b/buildscripts/kokoro/psm-light.cfg @@ -0,0 +1,17 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. 
+build_file: "grpc-java/buildscripts/kokoro/psm-interop-test-java.sh" +timeout_mins: 120 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*.log" + strip_prefix: "artifacts" + } +} +env_vars { + key: "PSM_TEST_SUITE" + value: "light" +} diff --git a/buildscripts/kokoro/psm-spiffe.cfg b/buildscripts/kokoro/psm-spiffe.cfg new file mode 100644 index 00000000000..b04d715fca1 --- /dev/null +++ b/buildscripts/kokoro/psm-spiffe.cfg @@ -0,0 +1,17 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. +build_file: "grpc-java/buildscripts/kokoro/psm-interop-test-java.sh" +timeout_mins: 240 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*.log" + strip_prefix: "artifacts" + } +} +env_vars { + key: "PSM_TEST_SUITE" + value: "spiffe" +} diff --git a/buildscripts/kokoro/unix.sh b/buildscripts/kokoro/unix.sh index 9b1a4054c7e..693768a0270 100755 --- a/buildscripts/kokoro/unix.sh +++ b/buildscripts/kokoro/unix.sh @@ -23,11 +23,6 @@ readonly GRPC_JAVA_DIR="$(cd "$(dirname "$0")"/../.. && pwd)" # cd to the root dir of grpc-java cd $(dirname $0)/../.. -# TODO(zpencer): always make sure we are using Oracle jdk8 -if [[ -f /usr/libexec/java_home ]]; then - JAVA_HOME=$(/usr/libexec/java_home -v"1.8.0") -fi - # ARCH is x86_64 unless otherwise specified. ARCH="${ARCH:-x86_64}" @@ -43,7 +38,13 @@ ARCH="$ARCH" buildscripts/make_dependencies.sh # Set properties via flags, do not pollute gradle.properties GRADLE_FLAGS="${GRADLE_FLAGS:-}" +GRADLE_FLAGS+=" --stacktrace" GRADLE_FLAGS+=" -PtargetArch=$ARCH" + +# For universal binaries on macOS, signal Gradle to use universal flags. 
+if [[ "$(uname -s)" == "Darwin" ]]; then + GRADLE_FLAGS+=" -PbuildUniversal=true" +fi GRADLE_FLAGS+=" -Pcheckstyle.ignoreFailures=false" GRADLE_FLAGS+=" -PfailOnWarnings=true" GRADLE_FLAGS+=" -PerrorProne=true" @@ -56,9 +57,9 @@ fi export GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" # Make protobuf discoverable by :grpc-compiler -export LD_LIBRARY_PATH=/tmp/protobuf/lib -export LDFLAGS=-L/tmp/protobuf/lib -export CXXFLAGS="-I/tmp/protobuf/include" +export LDFLAGS="$(PKG_CONFIG_PATH=/tmp/protobuf/lib/pkgconfig pkg-config --libs protobuf)" +export CXXFLAGS="$(PKG_CONFIG_PATH=/tmp/protobuf/lib/pkgconfig pkg-config --cflags protobuf)" +export LIBRARY_PATH=/tmp/protobuf/lib ./gradlew grpc-compiler:clean $GRADLE_FLAGS diff --git a/buildscripts/kokoro/windows.cfg b/buildscripts/kokoro/windows.cfg index bdfaa38904f..ec0a3c9ae34 100644 --- a/buildscripts/kokoro/windows.cfg +++ b/buildscripts/kokoro/windows.cfg @@ -2,7 +2,7 @@ # Location of the continuous shell script in repository. build_file: "grpc-java/buildscripts/kokoro/windows.bat" -timeout_mins: 45 +timeout_mins: 90 # We always build mvn artifacts. 
action { diff --git a/buildscripts/kokoro/windows32.bat b/buildscripts/kokoro/windows32.bat index ffd4d3b99a6..d51beba82f9 100644 --- a/buildscripts/kokoro/windows32.bat +++ b/buildscripts/kokoro/windows32.bat @@ -15,19 +15,21 @@ set ESCWORKSPACE=%WORKSPACE:\=\\% @rem Clear JAVA_HOME to prevent a different Java version from being used set JAVA_HOME= -set PATH=C:\Program Files\OpenJDK\openjdk-11.0.12_7\bin;%PATH% mkdir grpc-java-helper32 cd grpc-java-helper32 -call "%VS140COMNTOOLS%\vsvars32.bat" || exit /b 1 +call "%VS170COMNTOOLS%\..\..\VC\Auxiliary\Build\vcvars32.bat" || exit /b 1 call "%WORKSPACE%\buildscripts\make_dependencies.bat" || exit /b 1 cd "%WORKSPACE%" SET TARGET_ARCH=x86_32 SET FAIL_ON_WARNINGS=true -SET VC_PROTOBUF_LIBS=%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\build\\Release -SET VC_PROTOBUF_INCLUDE=%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\build\\include +SET PROTOBUF_VER=33.4 +SET PKG_CONFIG_PATH=%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\build\\protobuf-%PROTOBUF_VER%\\lib\\pkgconfig +SET VC_PROTOBUF_LIBS=/LIBPATH:%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\build\\protobuf-%PROTOBUF_VER%\\lib +SET VC_PROTOBUF_INCLUDE=%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\build\\protobuf-%PROTOBUF_VER%\\include +call :Get_Libs SET GRADLE_FLAGS=-PtargetArch=%TARGET_ARCH% -PfailOnWarnings=%FAIL_ON_WARNINGS% -PvcProtobufLibs=%VC_PROTOBUF_LIBS% -PvcProtobufInclude=%VC_PROTOBUF_INCLUDE% -PskipAndroid=true SET GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" @@ -50,3 +52,34 @@ IF NOT %GRADLEEXIT% == 0 ( cmd.exe /C "%WORKSPACE%\gradlew.bat --stop" cmd.exe /C "%WORKSPACE%\gradlew.bat %GRADLE_FLAGS% -Dorg.gradle.parallel=false -PrepositoryDir=%WORKSPACE%\artifacts clean grpc-compiler:build grpc-compiler:publish" || exit /b 1 + +goto :eof +:Get_Libs +SetLocal EnableDelayedExpansion +set "libs_list=" +for /f "tokens=*" %%a in ('pkg-config --libs protobuf') do ( + for %%b in (%%a) do ( 
+ set lib=%%b + set libfirst2char=!lib:~0,2! + if !libfirst2char!==-l ( + @rem remove the leading -l + set lib=!lib:~2! + @rem remove spaces + set lib=!lib: =! + set libprefix=!lib:~0,4! + if !libprefix!==absl ( + set lib=!lib!.lib + ) else ( + set lib=lib!lib!.lib + ) + if "!libs_list!"=="" ( + set libs_list=!lib! + ) else ( + set libs_list=!libs_list!,!lib! + ) + ) + ) +) +EndLocal & set "VC_PROTOBUF_LIBS=%VC_PROTOBUF_LIBS%,%libs_list%" +exit /b 0 + diff --git a/buildscripts/kokoro/windows64.bat b/buildscripts/kokoro/windows64.bat index 8542f1c0536..180025d5e82 100644 --- a/buildscripts/kokoro/windows64.bat +++ b/buildscripts/kokoro/windows64.bat @@ -14,19 +14,21 @@ set ESCWORKSPACE=%WORKSPACE:\=\\% @rem Clear JAVA_HOME to prevent a different Java version from being used set JAVA_HOME= -set PATH=C:\Program Files\OpenJDK\openjdk-11.0.12_7\bin;%PATH% mkdir grpc-java-helper64 cd grpc-java-helper64 -call "%VS140COMNTOOLS%\..\..\VC\bin\amd64\vcvars64.bat" || exit /b 1 +call "%VS170COMNTOOLS%\..\..\VC\Auxiliary\Build\vcvars64.bat" || exit /b 1 call "%WORKSPACE%\buildscripts\make_dependencies.bat" || exit /b 1 cd "%WORKSPACE%" SET TARGET_ARCH=x86_64 SET FAIL_ON_WARNINGS=true -SET VC_PROTOBUF_LIBS=%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\build\\Release -SET VC_PROTOBUF_INCLUDE=%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\build\\include +SET PROTOBUF_VER=33.4 +SET PKG_CONFIG_PATH=%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\build\\protobuf-%PROTOBUF_VER%\\lib\\pkgconfig +SET VC_PROTOBUF_LIBS=/LIBPATH:%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\build\\protobuf-%PROTOBUF_VER%\\lib +SET VC_PROTOBUF_INCLUDE=%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\build\\protobuf-%PROTOBUF_VER%\\include +call :Get_Libs SET GRADLE_FLAGS=-PtargetArch=%TARGET_ARCH% -PfailOnWarnings=%FAIL_ON_WARNINGS% -PvcProtobufLibs=%VC_PROTOBUF_LIBS% -PvcProtobufInclude=%VC_PROTOBUF_INCLUDE% -PskipAndroid=true SET 
GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" @@ -34,3 +36,34 @@ SET GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" cmd.exe /C "%WORKSPACE%\gradlew.bat --stop" cmd.exe /C "%WORKSPACE%\gradlew.bat %GRADLE_FLAGS% -Dorg.gradle.parallel=false -PrepositoryDir=%WORKSPACE%\artifacts grpc-compiler:clean grpc-compiler:build grpc-compiler:publish" || exit /b 1 + +goto :eof +:Get_Libs +SetLocal EnableDelayedExpansion +set "libs_list=" +for /f "tokens=*" %%a in ('pkg-config --libs protobuf') do ( + for %%b in (%%a) do ( + set lib=%%b + set libfirst2char=!lib:~0,2! + if !libfirst2char!==-l ( + @rem remove the leading -l + set lib=!lib:~2! + @rem remove spaces + set lib=!lib: =! + set libprefix=!lib:~0,4! + if !libprefix!==absl ( + set lib=!lib!.lib + ) else ( + set lib=lib!lib!.lib + ) + if "!libs_list!"=="" ( + set libs_list=!lib! + ) else ( + set libs_list=!libs_list!,!lib! + ) + ) + ) +) +EndLocal & set "VC_PROTOBUF_LIBS=%VC_PROTOBUF_LIBS%,%libs_list%" +exit /b 0 + diff --git a/buildscripts/make_dependencies.bat b/buildscripts/make_dependencies.bat index 2bbfd394d46..a11f84d998e 100644 --- a/buildscripts/make_dependencies.bat +++ b/buildscripts/make_dependencies.bat @@ -1,12 +1,16 @@ -set PROTOBUF_VER=21.7 -set CMAKE_NAME=cmake-3.3.2-win32-x86 +choco install -y pkgconfiglite +choco install -y openjdk --version=17.0 +set PATH=%PATH%;"c:\Program Files\OpenJDK\jdk-17\bin" +set PROTOBUF_VER=33.4 +set ABSL_VERSION=20250127.1 +set CMAKE_NAME=cmake-3.26.3-windows-x86_64 if not exist "protobuf-%PROTOBUF_VER%\build\Release\" ( call :installProto || exit /b 1 ) echo Compile gRPC-Java with something like: -echo -PtargetArch=x86_32 -PvcProtobufLibs=%cd%\protobuf-%PROTOBUF_VER%\build\Release -PvcProtobufInclude=%cd%\protobuf-%PROTOBUF_VER%\build\include +echo -PtargetArch=x86_32 -PvcProtobufLibPath=%cd%\protobuf-%PROTOBUF_VER%\build\protobuf-%PROTOBUF_VER%\lib -PvcProtobufInclude=%cd%\protobuf-%PROTOBUF_VER%\build\protobuf-%PROTOBUF_VER%\include 
-PvcProtobufLibs=insert-list-of-libs-from-pkg-config-output-here goto :eof @@ -20,25 +24,34 @@ if not exist "%CMAKE_NAME%" ( set PATH=%PATH%;%cd%\%CMAKE_NAME%\bin :hasCmake @rem GitHub requires TLSv1.2, and for whatever reason our powershell doesn't have it enabled -powershell -command "$ErrorActionPreference = 'stop'; & { [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12 ; iwr https://github.com/google/protobuf/archive/v%PROTOBUF_VER%.zip -OutFile protobuf.zip }" || exit /b 1 +powershell -command "$ProgressPreference = 'SilentlyContinue'; $ErrorActionPreference = 'stop'; & { [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12 ; iwr https://github.com/google/protobuf/releases/download/v%PROTOBUF_VER%/protobuf-%PROTOBUF_VER%.zip -OutFile protobuf.zip }" || exit /b 1 powershell -command "$ErrorActionPreference = 'stop'; & { Add-Type -AssemblyName System.IO.Compression.FileSystem; [System.IO.Compression.ZipFile]::ExtractToDirectory('protobuf.zip', '.') }" || exit /b 1 del protobuf.zip +powershell -command "$ProgressPreference = 'SilentlyContinue'; $ErrorActionPreference = 'stop'; & { [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12 ; iwr https://github.com/abseil/abseil-cpp/archive/refs/tags/%ABSL_VERSION%.zip -OutFile absl.zip }" || exit /b 1 +powershell -command "$ErrorActionPreference = 'stop'; & { Add-Type -AssemblyName System.IO.Compression.FileSystem; [System.IO.Compression.ZipFile]::ExtractToDirectory('absl.zip', '.') }" || exit /b 1 +del absl.zip +move abseil-cpp-%ABSL_VERSION% protobuf-%PROTOBUF_VER%\third_party\abseil-cpp mkdir protobuf-%PROTOBUF_VER%\build pushd protobuf-%PROTOBUF_VER%\build -@rem Workaround https://github.com/protocolbuffers/protobuf/issues/10174 -powershell -command "(Get-Content ..\cmake\extract_includes.bat.in) -replace '\.\.\\', '' | Out-File -encoding ascii ..\cmake\extract_includes.bat.in" @rem cmake does not detect x86_64 from the vcvars64.bat 
variables. -@rem If vcvars64.bat has set PLATFORM to X64, then inform cmake to use the Win64 version of VS -if "%PLATFORM%" == "X64" ( - @rem Note the space - SET CMAKE_VSARCH= Win64 +@rem If vcvars64.bat has set PLATFORM to X64, then inform cmake to use the Win64 version of VS, likewise for x32 +if "%PLATFORM%" == "x64" ( + SET CMAKE_VSARCH=-A x64 +) else if "%PLATFORM%" == "x86" ( + @rem -A x86 doesn't work: https://github.com/microsoft/vcpkg/issues/15465 + SET CMAKE_VSARCH=-DCMAKE_GENERATOR_PLATFORM=WIN32 ) else ( SET CMAKE_VSARCH= ) -cmake -Dprotobuf_BUILD_TESTS=OFF -G "Visual Studio %VisualStudioVersion:~0,2%%CMAKE_VSARCH%" .. || exit /b 1 -msbuild /maxcpucount /p:Configuration=Release /verbosity:minimal libprotoc.vcxproj || exit /b 1 -call extract_includes.bat || exit /b 1 +for /f "tokens=4 delims=\" %%a in ("%VCINSTALLDIR%") do ( + SET VC_YEAR=%%a +) +for /f "tokens=1 delims=." %%a in ("%VisualStudioVersion%") do ( + SET visual_studio_major_version=%%a +) +cmake -DCMAKE_CXX_STANDARD=17 -DABSL_MSVC_STATIC_RUNTIME=ON -Dprotobuf_BUILD_TESTS=OFF -DCMAKE_INSTALL_PREFIX=%cd%\protobuf-%PROTOBUF_VER% -DCMAKE_PREFIX_PATH=%cd%\protobuf-%PROTOBUF_VER% -G "Visual Studio %visual_studio_major_version% %VC_YEAR%" %CMAKE_VSARCH% .. || exit /b 1 +cmake --build . 
--config Release --target install || exit /b 1 popd goto :eof @@ -49,3 +62,4 @@ powershell -command "$ErrorActionPreference = 'stop'; & { iwr https://cmake.org/ powershell -command "$ErrorActionPreference = 'stop'; & { Add-Type -AssemblyName System.IO.Compression.FileSystem; [System.IO.Compression.ZipFile]::ExtractToDirectory('cmake.zip', '.') }" || exit /b 1 del cmake.zip goto :eof + diff --git a/buildscripts/make_dependencies.sh b/buildscripts/make_dependencies.sh index 3d02a72f4eb..8cbefddd2eb 100755 --- a/buildscripts/make_dependencies.sh +++ b/buildscripts/make_dependencies.sh @@ -3,13 +3,63 @@ # Build protoc set -evux -o pipefail -PROTOBUF_VERSION=21.7 +PROTOBUF_VERSION=33.4 +ABSL_VERSION=20250127.1 # ARCH is x86_64 bit unless otherwise specified. ARCH="${ARCH:-x86_64}" DOWNLOAD_DIR=/tmp/source INSTALL_DIR="/tmp/protobuf-cache/$PROTOBUF_VERSION/$(uname -s)-$ARCH" +BUILDSCRIPTS_DIR="$(cd "$(dirname "$0")" && pwd)" + +function build_and_install() { + if [[ "$1" == "abseil" ]]; then + TESTS_OFF_ARG=ABSL_BUILD_TEST_HELPERS + else + TESTS_OFF_ARG=protobuf_BUILD_TESTS + fi + if [[ "$(uname -s)" == "Darwin" ]]; then + cmake .. \ + -DCMAKE_CXX_STANDARD=17 -D${TESTS_OFF_ARG}=OFF -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_INSTALL_PREFIX="$INSTALL_DIR" \ + -DCMAKE_PREFIX_PATH="$INSTALL_DIR" \ + -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \ + -B. || exit 1 + elif [[ "$ARCH" == x86* ]]; then + CFLAGS=-m${ARCH#*_} CXXFLAGS=-m${ARCH#*_} cmake .. \ + -DCMAKE_CXX_STANDARD=17 -D${TESTS_OFF_ARG}=OFF -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_INSTALL_PREFIX="$INSTALL_DIR" \ + -DCMAKE_PREFIX_PATH="$INSTALL_DIR" \ + -B. || exit 1 + else + if [[ "$ARCH" == aarch_64 ]]; then + GCC_ARCH=aarch64-linux-gnu + elif [[ "$ARCH" == ppcle_64 ]]; then + GCC_ARCH=powerpc64le-linux-gnu + elif [[ "$ARCH" == s390_64 ]]; then + GCC_ARCH=s390x-linux-gnu + elif [[ "$ARCH" == loongarch_64 ]]; then + GCC_ARCH=loongarch64-unknown-linux-gnu + else + echo "Unknown architecture: $ARCH" + exit 1 + fi + cmake .. 
\ + -DCMAKE_CXX_STANDARD=17 -D${TESTS_OFF_ARG}=OFF -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_INSTALL_PREFIX="$INSTALL_DIR" \ + -DCMAKE_PREFIX_PATH="$INSTALL_DIR" \ + -Dcrosscompile_ARCH="$GCC_ARCH" \ + -DCMAKE_TOOLCHAIN_FILE=$BUILDSCRIPTS_DIR/toolchain.cmake \ + -B. || exit 1 + fi + export CMAKE_BUILD_PARALLEL_LEVEL="$NUM_CPU" + cmake --build . || exit 1 + # install here so we don't need sudo + cmake --install . || exit 1 +} + mkdir -p $DOWNLOAD_DIR +cd "$DOWNLOAD_DIR" # Start with a sane default NUM_CPU=4 @@ -19,6 +69,7 @@ fi if [[ $(uname) == 'Darwin' ]]; then NUM_CPU=$(sysctl -n hw.ncpu) fi +export CMAKE_BUILD_PARALLEL_LEVEL="$NUM_CPU" # Make protoc # Can't check for presence of directory as cache auto-creates it. @@ -26,28 +77,24 @@ if [ -f ${INSTALL_DIR}/bin/protoc ]; then echo "Not building protobuf. Already built" # TODO(ejona): swap to `brew install --devel protobuf` once it is up-to-date else - if [[ ! -d "$DOWNLOAD_DIR"/protobuf-"${PROTOBUF_VERSION}" ]]; then - curl -Ls https://github.com/google/protobuf/releases/download/v${PROTOBUF_VERSION}/protobuf-all-${PROTOBUF_VERSION}.tar.gz | tar xz -C $DOWNLOAD_DIR - fi - pushd $DOWNLOAD_DIR/protobuf-${PROTOBUF_VERSION} - # install here so we don't need sudo - if [[ "$ARCH" == x86* ]]; then - ./configure CFLAGS=-m${ARCH#*_} CXXFLAGS=-m${ARCH#*_} --disable-shared \ - --prefix="$INSTALL_DIR" - elif [[ "$ARCH" == aarch* ]]; then - ./configure --disable-shared --host=aarch64-linux-gnu --prefix="$INSTALL_DIR" - elif [[ "$ARCH" == ppc* ]]; then - ./configure --disable-shared --host=powerpc64le-linux-gnu --prefix="$INSTALL_DIR" - elif [[ "$ARCH" == s390* ]]; then - ./configure --disable-shared --host=s390x-linux-gnu --prefix="$INSTALL_DIR" - elif [[ "$ARCH" == loongarch* ]]; then - ./configure --disable-shared --host=loongarch64-unknown-linux-gnu --prefix="$INSTALL_DIR" + if [[ ! 
-d "protobuf-${PROTOBUF_VERSION}" ]]; then + curl -Ls "https://github.com/google/protobuf/releases/download/v${PROTOBUF_VERSION}/protobuf-${PROTOBUF_VERSION}.tar.gz" | tar xz + curl -Ls "https://github.com/abseil/abseil-cpp/archive/refs/tags/${ABSL_VERSION}.tar.gz" | tar xz fi # the same source dir is used for 32 and 64 bit builds, so we need to clean stale data first - make clean - make V=0 -j$NUM_CPU - make install + rm -rf "$DOWNLOAD_DIR/abseil-cpp-${ABSL_VERSION}/build" + mkdir "$DOWNLOAD_DIR/abseil-cpp-${ABSL_VERSION}/build" + pushd "$DOWNLOAD_DIR/abseil-cpp-${ABSL_VERSION}/build" + build_and_install "abseil" + popd + + rm -rf "$DOWNLOAD_DIR/protobuf-${PROTOBUF_VERSION}/build" + mkdir "$DOWNLOAD_DIR/protobuf-${PROTOBUF_VERSION}/build" + pushd "$DOWNLOAD_DIR/protobuf-${PROTOBUF_VERSION}/build" + build_and_install "protobuf" popd + + [ -d "$INSTALL_DIR/lib64" ] && mv "$INSTALL_DIR/lib64" "$INSTALL_DIR/lib" fi # If /tmp/protobuf exists then we just assume it's a symlink created by us. @@ -60,7 +107,9 @@ ln -s "$INSTALL_DIR" /tmp/protobuf cat <> "${grpc_java_dir}/gradle.properties" -skipAndroid=true -skipCodegen=true -org.gradle.parallel=true -org.gradle.jvmargs=-Xmx1024m -EOF - -export JAVA_OPTS="-Duser.home=/grpc-java/.current-user-home -Djava.util.prefs.userRoot=/grpc-java/.current-user-home/.java/.userPrefs" - -# build under x64 docker image to save time over building everything under -# aarch64 emulator. We've already built and tested the protoc binaries -# so for the rest of the build we will be using "-PskipCodegen=true" -# avoid further complicating the build. -docker run $DOCKER_ARGS --rm=true -v "${grpc_java_dir}":/grpc-java -w /grpc-java \ - --user "$(id -u):$(id -g)" -e JAVA_OPTS \ - openjdk:11-jdk-slim-buster \ - ./gradlew build -x test - -# Build and run java tests under aarch64 image. -# To be able to run this docker container on x64 machine, one needs to have -# qemu-user-static properly registered with binfmt_misc. 
-# The most important flag binfmt_misc flag we need is "F" (set by "--persistent yes"), -# which allows the qemu-aarch64-static binary to be loaded eagerly at the time of registration with binfmt_misc. -# That way, we can emulate aarch64 binaries running inside docker containers transparently, without needing the emulator -# binary to be accessible from the docker image we're emulating. -# Note that on newer distributions (such as glinux), simply "apt install qemu-user-static" is sufficient -# to install qemu-user-static with the right flags. -# A note on the "docker run" args used: -# - run docker container under current user's UID to avoid polluting the workspace -# - set the user.home property to avoid creating a "?" directory under grpc-java -docker run $DOCKER_ARGS --rm=true -v "${grpc_java_dir}":/grpc-java -w /grpc-java \ - --user "$(id -u):$(id -g)" -e JAVA_OPTS \ - arm64v8/openjdk:11-jdk-slim-buster \ - ./gradlew build diff --git a/buildscripts/sonatype-upload.sh b/buildscripts/sonatype-upload.sh index 16637149126..4baa4e46ca0 100755 --- a/buildscripts/sonatype-upload.sh +++ b/buildscripts/sonatype-upload.sh @@ -59,7 +59,7 @@ if [ -z "$USERNAME" -o -z "$PASSWORD" ]; then exit 1 fi -STAGING_URL="https://oss.sonatype.org/service/local/staging" +STAGING_URL="https://ossrh-staging-api.central.sonatype.com/service/local/staging" # We go through the effort of using deloyByRepositoryId/ because it is # _substantially_ faster to upload files than deploy/maven2/. When using @@ -108,3 +108,18 @@ XML=" " curl --fail-with-body -X POST -d "$XML" -u "$USERPASS" -H "Content-Type: application/xml" \ "$STAGING_URL/profiles/$PROFILE_ID/finish" + +# TODO (okshiva): After 2-3 releases make it automatic. +# After closing the repository on the staging API, we must manually trigger +# its upload to the main Central Publisher Portal. We set publishing_type=automatic +# to have it release automatically upon passing validation. 
+# echo "Triggering release of repository ${REPOID} to the Central Portal" + +# MANUAL_API_URL="https://ossrh-staging-api.central.sonatype.com/service/local/manual" + +#curl --fail-with-body -X POST \ +# -H "Authorization: Bearer ${USERPASS}" \ +# -H "Content-Type: application/json" \ +# "${MANUAL_API_URL}/upload/repository/${REPOID}?publishing_type=automatic" + +# echo "Release triggered. Monitor progress at https://central.sonatype.com/publishing/deployments" diff --git a/buildscripts/toolchain.cmake b/buildscripts/toolchain.cmake new file mode 100644 index 00000000000..b71515cebda --- /dev/null +++ b/buildscripts/toolchain.cmake @@ -0,0 +1,9 @@ +set(CMAKE_SYSTEM_NAME Linux) + +set(CMAKE_C_COMPILER "${crosscompile_ARCH}-gcc") +set(CMAKE_CXX_COMPILER "${crosscompile_ARCH}-g++") +set(CMAKE_FIND_ROOT_PATH "/usr/${crosscompile_ARCH}/") + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) diff --git a/census/BUILD.bazel b/census/BUILD.bazel index aec16c46af0..f017eeaf8bd 100644 --- a/census/BUILD.bazel +++ b/census/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( @@ -10,6 +11,7 @@ java_library( "//api", "//context", artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), artifact("io.opencensus:opencensus-api"), artifact("io.opencensus:opencensus-contrib-grpc-metrics"), diff --git a/census/build.gradle b/census/build.gradle index c1dc53e4c05..c7cb02c15a0 100644 --- a/census/build.gradle +++ b/census/build.gradle @@ -27,12 +27,20 @@ dependencies { project(':grpc-testing'), libraries.opencensus.impl - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + 
artifact { + extension = "signature" + } + } } tasks.named("javadoc").configure { - failOnError false // no public or protected classes found to document + failOnError = false // no public or protected classes found to document exclude 'io/grpc/census/internal/**' exclude 'io/grpc/census/Internal*' } diff --git a/census/src/main/java/io/grpc/census/CensusStatsModule.java b/census/src/main/java/io/grpc/census/CensusStatsModule.java index ad16bef9604..8f571ceb627 100644 --- a/census/src/main/java/io/grpc/census/CensusStatsModule.java +++ b/census/src/main/java/io/grpc/census/CensusStatsModule.java @@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Channel; @@ -62,7 +63,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Provides factories for {@link StreamTracer} that records stats to Census. diff --git a/census/src/main/java/io/grpc/census/GrpcCensus.java b/census/src/main/java/io/grpc/census/GrpcCensus.java new file mode 100644 index 00000000000..c564c349ae4 --- /dev/null +++ b/census/src/main/java/io/grpc/census/GrpcCensus.java @@ -0,0 +1,176 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.census; + +import com.google.common.base.Stopwatch; +import com.google.common.base.Supplier; +import io.grpc.ClientInterceptor; +import io.grpc.ExperimentalApi; +import io.grpc.ManagedChannelBuilder; +import io.grpc.ServerBuilder; +import io.grpc.ServerStreamTracer; +import io.opencensus.trace.Tracing; + +/** + * The entrypoint for OpenCensus instrumentation functionality in gRPC. + * + *

GrpcCensus uses {@link io.opencensus.api.OpenCensus} APIs for instrumentation. + * + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/12178") +public final class GrpcCensus { + + private final boolean statsEnabled; + private final boolean tracingEnabled; + + private GrpcCensus(Builder builder) { + this.statsEnabled = builder.statsEnabled; + this.tracingEnabled = builder.tracingEnabled; + } + + /** + * Creates a new builder for {@link GrpcCensus}. + */ + public static Builder newBuilder() { + return new Builder(); + } + + private static final Supplier STOPWATCH_SUPPLIER = new Supplier() { + @Override + public Stopwatch get() { + return Stopwatch.createUnstarted(); + } + }; + + /** + * Configures a {@link ServerBuilder} to enable census stats and tracing. + * + * @param serverBuilder The server builder to configure. + * @return The configured server builder. + */ + public > T configureServerBuilder(T serverBuilder) { + if (statsEnabled) { + serverBuilder.addStreamTracerFactory(newServerStatsStreamTracerFactory()); + } + if (tracingEnabled) { + serverBuilder.addStreamTracerFactory(newServerTracingStreamTracerFactory()); + } + return serverBuilder; + } + + /** + * Configures a {@link ManagedChannelBuilder} to enable census stats and tracing. + * + * @param channelBuilder The channel builder to configure. + * @return The configured channel builder. + */ + public > T configureChannelBuilder(T channelBuilder) { + if (statsEnabled) { + channelBuilder.intercept(newClientStatsInterceptor()); + } + if (tracingEnabled) { + channelBuilder.intercept(newClientTracingInterceptor()); + } + return channelBuilder; + } + + /** + * Returns a {@link ClientInterceptor} with default stats implementation. 
+ */ + private static ClientInterceptor newClientStatsInterceptor() { + CensusStatsModule censusStats = + new CensusStatsModule( + STOPWATCH_SUPPLIER, + true, + true, + true, + false, + true); + return censusStats.getClientInterceptor(); + } + + /** + * Returns a {@link ClientInterceptor} with default tracing implementation. + */ + private static ClientInterceptor newClientTracingInterceptor() { + CensusTracingModule censusTracing = + new CensusTracingModule( + Tracing.getTracer(), + Tracing.getPropagationComponent().getBinaryFormat()); + return censusTracing.getClientInterceptor(); + } + + /** + * Returns a {@link ServerStreamTracer.Factory} with default stats implementation. + */ + private static ServerStreamTracer.Factory newServerStatsStreamTracerFactory() { + CensusStatsModule censusStats = + new CensusStatsModule( + STOPWATCH_SUPPLIER, + true, + true, + true, + false, + true); + return censusStats.getServerTracerFactory(); + } + + /** + * Returns a {@link ServerStreamTracer.Factory} with default tracing implementation. + */ + private static ServerStreamTracer.Factory newServerTracingStreamTracerFactory() { + CensusTracingModule censusTracing = + new CensusTracingModule( + Tracing.getTracer(), + Tracing.getPropagationComponent().getBinaryFormat()); + return censusTracing.getServerTracerFactory(); + } + + /** + * Builder for {@link GrpcCensus}. + */ + public static final class Builder { + private boolean statsEnabled = true; + private boolean tracingEnabled = true; + + private Builder() { + } + + /** + * Disables stats collection. + */ + public Builder disableStats() { + this.statsEnabled = false; + return this; + } + + /** + * Disables tracing. + */ + public Builder disableTracing() { + this.tracingEnabled = false; + return this; + } + + /** + * Builds a new {@link GrpcCensus}. 
+ */ + public GrpcCensus build() { + return new GrpcCensus(this); + } + } +} diff --git a/census/src/test/java/io/grpc/census/CensusModulesTest.java b/census/src/test/java/io/grpc/census/CensusModulesTest.java index 6ccaf78314f..9e0b4d935d3 100644 --- a/census/src/test/java/io/grpc/census/CensusModulesTest.java +++ b/census/src/test/java/io/grpc/census/CensusModulesTest.java @@ -56,6 +56,7 @@ import io.grpc.ClientInterceptors; import io.grpc.ClientStreamTracer; import io.grpc.Context; +import io.grpc.KnownLength; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.ServerCall; @@ -99,6 +100,7 @@ import io.opencensus.trace.Tracer; import io.opencensus.trace.propagation.BinaryFormat; import io.opencensus.trace.propagation.SpanContextParseException; +import java.io.IOException; import java.io.InputStream; import java.util.HashSet; import java.util.List; @@ -136,7 +138,7 @@ public class CensusModulesTest { ClientStreamTracer.StreamInfo.newBuilder() .setCallOptions(CallOptions.DEFAULT.withOption(NAME_RESOLUTION_DELAYED, 10L)).build(); - private static class StringInputStream extends InputStream { + private static class StringInputStream extends InputStream implements KnownLength { final String string; StringInputStream(String string) { @@ -149,6 +151,11 @@ public int read() { // passed to the InProcess server and consumed by MARSHALLER.parse(). throw new UnsupportedOperationException("Should not be called"); } + + @Override + public int available() throws IOException { + return string == null ? 
0 : string.length(); + } } private static final MethodDescriptor.Marshaller MARSHALLER = diff --git a/compiler/BUILD.bazel b/compiler/BUILD.bazel index 753f485074e..e8a0571e134 100644 --- a/compiler/BUILD.bazel +++ b/compiler/BUILD.bazel @@ -1,4 +1,5 @@ load("@rules_cc//cc:defs.bzl", "cc_binary") +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") load("//:java_grpc_library.bzl", "java_rpc_toolchain") @@ -18,11 +19,11 @@ cc_binary( java_library( name = "java_grpc_library_deps__do_not_reference", + visibility = ["//xds:__pkg__"], exports = [ "//api", "//protobuf", "//stub", - "//stub:javax_annotation", artifact("com.google.code.findbugs:jsr305"), artifact("com.google.guava:guava"), "@com_google_protobuf//:protobuf_java", @@ -35,7 +36,6 @@ java_library( "//api", "//protobuf-lite", "//stub", - "//stub:javax_annotation", artifact("com.google.code.findbugs:jsr305"), artifact("com.google.guava:guava"), ], diff --git a/compiler/build.gradle b/compiler/build.gradle index 3c8e9358401..f970f629e19 100644 --- a/compiler/build.gradle +++ b/compiler/build.gradle @@ -76,6 +76,7 @@ model { aarch_64 { architecture "aarch_64" } s390_64 { architecture "s390_64" } loongarch_64 { architecture "loongarch_64" } + riscv_64 { architecture "riscv_64" } } components { @@ -84,6 +85,7 @@ model { 'x86_32', 'x86_64', 'ppcle_64', + 'riscv_64', 'aarch_64', 's390_64', 'loongarch_64' @@ -100,19 +102,26 @@ model { all { if (toolChain in Gcc || toolChain in Clang) { cppCompiler.define("GRPC_VERSION", version) - cppCompiler.args "--std=c++0x" + cppCompiler.args "--std=c++17" addEnvArgs("CXXFLAGS", cppCompiler.args) addEnvArgs("CPPFLAGS", cppCompiler.args) + if (project.hasProperty('buildUniversal') && + project.getProperty('buildUniversal').toBoolean() && + osdetector.os == "osx") { + cppCompiler.args "-arch", "arm64", "-arch", "x86_64" + linker.args "-arch", "arm64", "-arch", "x86_64" + } if (osdetector.os == "osx") { cppCompiler.args 
"-mmacosx-version-min=10.7", "-stdlib=libc++" + linker.args "-framework", "CoreFoundation" addLibraryIfNotLinked('protoc', linker.args) addLibraryIfNotLinked('protobuf', linker.args) } else if (osdetector.os == "windows") { linker.args "-static", "-lprotoc", "-lprotobuf", "-static-libgcc", "-static-libstdc++", "-s" - } else if (osdetector.arch == "ppcle_64") { - linker.args "-Wl,-Bstatic", "-lprotoc", "-lprotobuf", "-Wl,-Bdynamic", "-lpthread", "-s" - } else { + } else if (osdetector.arch == "ppcle_64") { + linker.args "-Wl,-Bstatic", "-lprotoc", "-lprotobuf", "-Wl,-Bdynamic", "-lpthread", "-s" + } else { // Link protoc, protobuf, libgcc and libstdc++ statically. // Link other (system) libraries dynamically. // Clang under OSX doesn't support these options. @@ -123,13 +132,15 @@ model { } else if (toolChain in VisualCpp) { usingVisualCpp = true cppCompiler.define("GRPC_VERSION", version) - cppCompiler.args "/EHsc", "/MT" + cppCompiler.args "/EHsc", "/MT", "/std:c++17" if (rootProject.hasProperty('vcProtobufInclude')) { cppCompiler.args "/I${rootProject.vcProtobufInclude}" - } - linker.args "libprotobuf.lib", "libprotoc.lib" + } + linker.args.add("libprotoc.lib") + linker.args.add("libprotobuf.lib") if (rootProject.hasProperty('vcProtobufLibs')) { - linker.args "/LIBPATH:${rootProject.vcProtobufLibs}" + String libsList = rootProject.property('vcProtobufLibs') as String + libsList.split(',').each() { lib -> linker.args.add(lib) } } } } @@ -144,15 +155,13 @@ sourceSets { dependencies { testImplementation project(':grpc-protobuf'), - project(':grpc-stub'), - libraries.javax.annotation + project(':grpc-stub') testLiteImplementation project(':grpc-protobuf-lite'), - project(':grpc-stub'), - libraries.javax.annotation + project(':grpc-stub') } tasks.named("compileTestJava").configure { - options.errorprone.excludedPaths = ".*/build/generated/source/proto/.*" + options.errorprone.excludedPaths = ".*/build/generated/sources/proto/.*" } 
tasks.named("compileTestLiteJava").configure { @@ -160,7 +169,7 @@ tasks.named("compileTestLiteJava").configure { options.compilerArgs += [ "-Xlint:-cast" ] - options.errorprone.excludedPaths = ".*/build/generated/source/proto/.*" + options.errorprone.excludedPaths = ".*/build/generated/sources/proto/.*" } tasks.named("checkstyleTestLite").configure { @@ -184,7 +193,11 @@ protobuf { inputs.file javaPluginPath } ofSourceSet('test').configureEach { - plugins { grpc {} } + plugins { + grpc { + option '@generated=javax' + } + } } ofSourceSet('testLite').configureEach { builtins { @@ -193,7 +206,6 @@ protobuf { plugins { grpc { option 'lite' - option '@generated=omit' } } } @@ -239,9 +251,10 @@ def checkArtifacts = tasks.register("checkArtifacts") { if (ret.exitValue != 0) { throw new GradleException("dumpbin exited with " + ret.exitValue) } - def dlls = os.toString() =~ /Image has the following dependencies:\s+(.*)\s+Summary/ - if (dlls[0][1] != "KERNEL32.dll") { - throw new Exception("unexpected dll deps: " + dlls[0][1]); + def dlls_match_results = os.toString() =~ /Image has the following dependencies:([\S\s]*)Summary/ + def dlls = dlls_match_results[0][1].trim().split("\\s+").sort() + if (dlls != ["KERNEL32.dll", "dbghelp.dll"]) { + throw new Exception("unexpected dll deps: " + dlls); } os.reset() ret = exec { diff --git a/compiler/check-artifact.sh b/compiler/check-artifact.sh index 4d0c2fa6286..83b41f50282 100755 --- a/compiler/check-artifact.sh +++ b/compiler/check-artifact.sh @@ -86,17 +86,17 @@ checkArch () fi fi elif [[ "$OS" == osx ]]; then - format="$(file -b "$1" | grep -o "[^ ]*$")" - echo Format=$format - if [[ "$ARCH" == x86_32 ]]; then - assertEq "$format" "i386" $LINENO - elif [[ "$ARCH" == x86_64 ]]; then - assertEq "$format" "x86_64" $LINENO - elif [[ "$ARCH" == aarch_64 ]]; then - assertEq "$format" "arm64" $LINENO - else - fail "Unsupported arch: $ARCH" + # For macOS, we now build a universal binary. 
We check that both + # required architectures are present. + format="$(lipo -archs "$1")" + echo "Architectures found: $format" + if ! echo "$format" | grep -q "x86_64"; then + fail "Universal binary is missing x86_64 architecture." + fi + if ! echo "$format" | grep -q "arm64"; then + fail "Universal binary is missing arm64 architecture." fi + echo "Universal binary check successful." else fail "Unsupported system: $OS" fi @@ -114,7 +114,7 @@ checkDependencies () white_list="KERNEL32\.dll\|msvcrt\.dll\|USER32\.dll" elif [[ "$OS" == linux ]]; then dump_cmd='objdump -x '"$1"' | grep "NEEDED"' - white_list="libpthread\.so\.0\|libstdc++\.so\.6\|libc\.so\.6" + white_list="libpthread\.so\.0\|libstdc++\.so\.6\|libc\.so\.6\|librt\.so\.1\|libm\.so\.6" if [[ "$ARCH" == x86_32 ]]; then white_list="${white_list}\|libm\.so\.6" elif [[ "$ARCH" == x86_64 ]]; then diff --git a/compiler/src/java_plugin/cpp/java_generator.cpp b/compiler/src/java_plugin/cpp/java_generator.cpp index 8693fad1b66..a81d54791b4 100644 --- a/compiler/src/java_plugin/cpp/java_generator.cpp +++ b/compiler/src/java_plugin/cpp/java_generator.cpp @@ -143,11 +143,24 @@ static std::set java_keywords = { "false", }; +// Methods on java.lang.Object that take no arguments. +static std::set java_object_methods = { + "clone", + "finalize", + "getClass", + "hashCode", + "notify", + "notifyAll", + "toString", + "wait", +}; + // Adjust a method name prefix identifier to follow the JavaBean spec: // - decapitalize the first letter // - remove embedded underscores & capitalize the following letter -// Finally, if the result is a reserved java keyword, append an underscore. -static std::string MixedLower(const std::string& word) { +// Finally, if the result is a reserved java keyword or an Object method, +// append an underscore. 
+static std::string MixedLower(std::string word, bool mangle_object_methods = false) { std::string w; w += tolower(word[0]); bool after_underscore = false; @@ -159,7 +172,9 @@ static std::string MixedLower(const std::string& word) { after_underscore = false; } } - if (java_keywords.find(w) != java_keywords.end()) { + if (java_keywords.find(w) != java_keywords.end() || + (mangle_object_methods && + java_object_methods.find(w) != java_object_methods.end())) { return w + "_"; } return w; @@ -169,7 +184,7 @@ static std::string MixedLower(const std::string& word) { // - An underscore is inserted where a lower case letter is followed by an // upper case letter. // - All letters are converted to upper case -static std::string ToAllUpperCase(const std::string& word) { +static std::string ToAllUpperCase(std::string word) { std::string w; for (size_t i = 0; i < word.length(); ++i) { w += toupper(word[i]); @@ -180,24 +195,25 @@ static std::string ToAllUpperCase(const std::string& word) { return w; } -static inline std::string LowerMethodName(const MethodDescriptor* method) { - return MixedLower(method->name()); +static inline std::string LowerMethodName(const MethodDescriptor* method, + bool mangle_object_methods = false) { + return MixedLower(std::string(method->name()), mangle_object_methods); } static inline std::string MethodPropertiesFieldName(const MethodDescriptor* method) { - return "METHOD_" + ToAllUpperCase(method->name()); + return "METHOD_" + ToAllUpperCase(std::string(method->name())); } static inline std::string MethodPropertiesGetterName(const MethodDescriptor* method) { - return MixedLower("get_" + method->name() + "_method"); + return MixedLower("get_" + std::string(method->name()) + "_method"); } static inline std::string MethodIdFieldName(const MethodDescriptor* method) { - return "METHODID_" + ToAllUpperCase(method->name()); + return "METHODID_" + ToAllUpperCase(std::string(method->name())); } static inline std::string MessageFullJavaName(const Descriptor* 
desc) { - return protobuf::compiler::java::ClassName(desc); + return protobuf::compiler::java::QualifiedClassName(desc); } // TODO(nmittler): Remove once protobuf includes javadoc methods in distribution. @@ -355,13 +371,15 @@ enum StubType { BLOCKING_CLIENT_IMPL = 5, FUTURE_CLIENT_IMPL = 6, ABSTRACT_CLASS = 7, - NONE = 8, + BLOCKING_V2_CLIENT_IMPL = 8, + NONE = 999, }; enum CallType { ASYNC_CALL = 0, BLOCKING_CALL = 1, - FUTURE_CALL = 2 + FUTURE_CALL = 2, + BLOCKING_V2_CALL = 3, }; // TODO(nmittler): Remove once protobuf includes javadoc methods in distribution. @@ -404,12 +422,15 @@ static void GrpcWriteServiceDocComment(Printer* printer, StubType type) { printer->Print("/**\n"); - std::map vars = {{"service", service->name()}}; + std::map vars = {{"service", std::string(service->name())}}; switch (type) { case ASYNC_CLIENT_IMPL: printer->Print(vars, " * A stub to allow clients to do asynchronous rpc calls to service $service$.\n"); break; case BLOCKING_CLIENT_IMPL: + printer->Print(vars, " * A stub to allow clients to do limited synchronous rpc calls to service $service$.\n"); + break; + case BLOCKING_V2_CLIENT_IMPL: printer->Print(vars, " * A stub to allow clients to do synchronous rpc calls to service $service$.\n"); break; case FUTURE_CLIENT_IMPL: @@ -515,7 +536,8 @@ static void PrintMethodFields( " .setResponseMarshaller($ProtoUtils$.marshaller(\n" " $output_type$.getDefaultInstance()))\n"); - (*vars)["proto_method_descriptor_supplier"] = service->name() + "MethodDescriptorSupplier"; + (*vars)["proto_method_descriptor_supplier"] + = std::string(service->name()) + "MethodDescriptorSupplier"; if (flavor == ProtoFlavor::NORMAL) { p->Print( *vars, @@ -555,6 +577,9 @@ static void PrintStubFactory( case BLOCKING_CLIENT_IMPL: stub_type_name = "Blocking"; break; + case BLOCKING_V2_CLIENT_IMPL: + stub_type_name = "BlockingV2"; + break; default: GRPC_CODEGEN_FAIL << "Cannot generate StubFactory for StubType: " << type; } @@ -575,7 +600,7 @@ static void PrintStub( 
const ServiceDescriptor* service, std::map* vars, Printer* p, StubType type) { - const std::string service_name = service->name(); + std::string service_name = std::string(service->name()); (*vars)["service_name"] = service_name; std::string stub_name = service_name; std::string stub_base_class_name = "AbstractStub"; @@ -597,6 +622,11 @@ static void PrintStub( stub_name += "BlockingStub"; stub_base_class_name = "AbstractBlockingStub"; break; + case BLOCKING_V2_CLIENT_IMPL: + call_type = BLOCKING_V2_CALL; + stub_name += "BlockingV2Stub"; + stub_base_class_name = "AbstractBlockingStub"; + break; case FUTURE_CLIENT_IMPL: call_type = FUTURE_CALL; stub_name += "FutureStub"; @@ -662,10 +692,12 @@ static void PrintStub( const MethodDescriptor* method = service->method(i); (*vars)["input_type"] = MessageFullJavaName(method->input_type()); (*vars)["output_type"] = MessageFullJavaName(method->output_type()); - (*vars)["lower_method_name"] = LowerMethodName(method); - (*vars)["method_method_name"] = MethodPropertiesGetterName(method); bool client_streaming = method->client_streaming(); bool server_streaming = method->server_streaming(); + bool mangle_object_methods = (call_type == BLOCKING_V2_CALL && client_streaming) + || (call_type == BLOCKING_CALL && client_streaming && server_streaming); + (*vars)["lower_method_name"] = LowerMethodName(method, mangle_object_methods); + (*vars)["method_method_name"] = MethodPropertiesGetterName(method); if (call_type == BLOCKING_CALL && client_streaming) { // Blocking client interface with client streaming is not available @@ -679,13 +711,17 @@ static void PrintStub( // Method signature p->Print("\n"); - // TODO(nmittler): Replace with WriteMethodDocComment once included by the protobuf distro. 
GrpcWriteMethodDocComment(p, method); if (method->options().deprecated()) { p->Print(*vars, "@$Deprecated$\n"); } + if ((call_type == BLOCKING_CALL && client_streaming && server_streaming) + || (call_type == BLOCKING_V2_CALL && (client_streaming || server_streaming))) { + p->Print(*vars, "@io.grpc.ExperimentalApi(\"https://github.com/grpc/grpc-java/issues/10918\")\n"); + } + if (!interface) { p->Print("public "); } else { @@ -695,7 +731,12 @@ static void PrintStub( case BLOCKING_CALL: GRPC_CODEGEN_CHECK(!client_streaming) << "Blocking client interface with client streaming is unavailable"; - if (server_streaming) { + if (client_streaming && server_streaming) { + p->Print( + *vars, + "$BlockingClientCall$<$input_type$, $output_type$>\n" + " $lower_method_name$()"); + } else if (server_streaming) { // Server streaming p->Print( *vars, @@ -708,6 +749,26 @@ static void PrintStub( "$output_type$ $lower_method_name$($input_type$ request)"); } break; + case BLOCKING_V2_CALL: + if (client_streaming) { // Both Bidi and Client Streaming + p->Print( + *vars, + "$BlockingClientCall$<$input_type$, $output_type$>\n" + " $lower_method_name$()"); + } else if (server_streaming) { + // Server streaming + p->Print( + *vars, + "$BlockingClientCall$\n" + " $lower_method_name$($input_type$ request)"); + } else { + // Simple RPC + (*vars)["throws_decl"] = " throws io.grpc.StatusException"; + p->Print( + *vars, + "$output_type$ $lower_method_name$($input_type$ request)$throws_decl$"); + } + break; case ASYNC_CALL: if (client_streaming) { // Bidirectional streaming or client streaming @@ -753,21 +814,47 @@ static void PrintStub( "$method_method_name$(), responseObserver);\n"); } } else if (!interface) { - switch (call_type) { + switch (call_type) { case BLOCKING_CALL: GRPC_CODEGEN_CHECK(!client_streaming) - << "Blocking client streaming interface is not available"; - if (server_streaming) { - (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingServerStreamingCall"; - 
(*vars)["params"] = "request"; - } else { - (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingUnaryCall"; - (*vars)["params"] = "request"; + << "Blocking client and bidi streaming interface are not available"; + if (server_streaming) { + (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingServerStreamingCall"; + (*vars)["params"] = "request"; + } else { + (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingUnaryCall"; + (*vars)["params"] = "request"; + } + p->Print( + *vars, + "return $calls_method$(\n" + " getChannel(), $method_method_name$(), getCallOptions(), $params$);\n"); + break; + case BLOCKING_V2_CALL: + if (client_streaming) { // client and bidi streaming + if (server_streaming) { + (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingBidiStreamingCall"; + } else { + (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingClientStreamingCall"; + } + p->Print( + *vars, + "return $calls_method$(\n" + " getChannel(), $method_method_name$(), getCallOptions());\n"); + } else { // server streaming and unary + (*vars)["params"] = "request"; + if (server_streaming) { + (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall"; + } else { + (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingV2UnaryCall"; + (*vars)["throws_decl"] = " throws io.grpc.StatusException"; + } + + p->Print( + *vars, + "return $calls_method$(\n" + " getChannel(), $method_method_name$(), getCallOptions(), $params$);\n"); } - p->Print( - *vars, - "return $calls_method$(\n" - " getChannel(), $method_method_name$(), getCallOptions(), $params$);\n"); break; case ASYNC_CALL: if (server_streaming) { @@ -804,7 +891,7 @@ static void PrintStub( "return $calls_method$(\n" " getChannel().newCall($method_method_name$(), getCallOptions()), request);\n"); break; - } + } } else { GRPC_CODEGEN_FAIL << "Do not create Stub interfaces"; } @@ -821,8 +908,7 @@ static void PrintAbstractClassStub( const ServiceDescriptor* service, std::map* 
vars, Printer* p) { - const std::string service_name = service->name(); - (*vars)["service_name"] = service_name; + (*vars)["service_name"] = service->name(); GrpcWriteServiceDocComment(p, service, ABSTRACT_CLASS); if (service->options().deprecated()) { @@ -956,14 +1042,15 @@ static void PrintGetServiceDescriptorMethod(const ServiceDescriptor* service, std::map* vars, Printer* p, ProtoFlavor flavor) { - (*vars)["service_name"] = service->name(); + std::string service_name = std::string(service->name()); + (*vars)["service_name"] = service_name; if (flavor == ProtoFlavor::NORMAL) { - (*vars)["proto_base_descriptor_supplier"] = service->name() + "BaseDescriptorSupplier"; - (*vars)["proto_file_descriptor_supplier"] = service->name() + "FileDescriptorSupplier"; - (*vars)["proto_method_descriptor_supplier"] = service->name() + "MethodDescriptorSupplier"; - (*vars)["proto_class_name"] = protobuf::compiler::java::ClassName(service->file()); + (*vars)["proto_base_descriptor_supplier"] = service_name + "BaseDescriptorSupplier"; + (*vars)["proto_file_descriptor_supplier"] = service_name + "FileDescriptorSupplier"; + (*vars)["proto_method_descriptor_supplier"] = service_name + "MethodDescriptorSupplier"; + (*vars)["proto_class_name"] = protobuf::compiler::java::QualifiedClassName(service->file()); p->Print( *vars, "private static abstract class $proto_base_descriptor_supplier$\n" @@ -1173,6 +1260,21 @@ static void PrintService(const ServiceDescriptor* service, p->Outdent(); p->Print("}\n\n"); + // TODO(nmittler): Replace with WriteDocComment once included by protobuf distro. 
+ GrpcWriteDocComment(p, " Creates a new blocking-style stub that supports all types of calls " + "on the service"); + p->Print( + *vars, + "public static $service_name$BlockingV2Stub newBlockingV2Stub(\n" + " $Channel$ channel) {\n"); + p->Indent(); + PrintStubFactory(service, vars, p, BLOCKING_V2_CLIENT_IMPL); + p->Print( + *vars, + "return $service_name$BlockingV2Stub.newStub(factory, channel);\n"); + p->Outdent(); + p->Print("}\n\n"); + // TODO(nmittler): Replace with WriteDocComment once included by protobuf distro. GrpcWriteDocComment(p, " Creates a new blocking-style stub that supports unary and streaming " "output calls on the service"); @@ -1206,6 +1308,7 @@ static void PrintService(const ServiceDescriptor* service, PrintStub(service, vars, p, ASYNC_INTERFACE); PrintAbstractClassStub(service, vars, p); PrintStub(service, vars, p, ASYNC_CLIENT_IMPL); + PrintStub(service, vars, p, BLOCKING_V2_CLIENT_IMPL); PrintStub(service, vars, p, BLOCKING_CLIENT_IMPL); PrintStub(service, vars, p, FUTURE_CLIENT_IMPL); @@ -1257,6 +1360,7 @@ void GenerateService(const ServiceDescriptor* service, vars["RpcMethod"] = "io.grpc.stub.annotations.RpcMethod"; vars["MethodDescriptor"] = "io.grpc.MethodDescriptor"; vars["StreamObserver"] = "io.grpc.stub.StreamObserver"; + vars["BlockingClientCall"] = "io.grpc.stub.BlockingClientCall"; vars["Iterator"] = "java.util.Iterator"; vars["GrpcGenerated"] = "io.grpc.stub.annotations.GrpcGenerated"; vars["ListenableFuture"] = @@ -1280,7 +1384,7 @@ void GenerateService(const ServiceDescriptor* service, } std::string ServiceJavaPackage(const FileDescriptor* file) { - std::string result = protobuf::compiler::java::ClassName(file); + std::string result = protobuf::compiler::java::QualifiedClassName(file); size_t last_dot_pos = result.find_last_of('.'); if (last_dot_pos != std::string::npos) { result.resize(last_dot_pos); @@ -1291,7 +1395,7 @@ std::string ServiceJavaPackage(const FileDescriptor* file) { } std::string ServiceClassName(const 
ServiceDescriptor* service) { - return service->name() + "Grpc"; + return std::string(service->name()) + "Grpc"; } } // namespace java_grpc_generator diff --git a/compiler/src/java_plugin/cpp/java_plugin.cpp b/compiler/src/java_plugin/cpp/java_plugin.cpp index c3aec58ed8e..4b02d6e9884 100644 --- a/compiler/src/java_plugin/cpp/java_plugin.cpp +++ b/compiler/src/java_plugin/cpp/java_plugin.cpp @@ -23,6 +23,9 @@ #include "java_generator.h" #include +#if GOOGLE_PROTOBUF_VERSION >= 5027000 +#include +#endif #include #include #include @@ -55,7 +58,15 @@ class JavaGrpcGenerator : public protobuf::compiler::CodeGenerator { return protobuf::Edition::EDITION_PROTO2; } protobuf::Edition GetMaximumEdition() const override { +#if GOOGLE_PROTOBUF_VERSION >= 6032000 + return protobuf::Edition::EDITION_2024; +#else return protobuf::Edition::EDITION_2023; +#endif + } + std::vector GetFeatureExtensions() + const override { + return {GetExtensionReflection(pb::java)}; } #else uint64_t GetSupportedFeatures() const override { @@ -73,7 +84,7 @@ class JavaGrpcGenerator : public protobuf::compiler::CodeGenerator { java_grpc_generator::ProtoFlavor flavor = java_grpc_generator::ProtoFlavor::NORMAL; java_grpc_generator::GeneratedAnnotation generated_annotation = - java_grpc_generator::GeneratedAnnotation::JAVAX; + java_grpc_generator::GeneratedAnnotation::OMIT; bool disable_version = false; for (size_t i = 0; i < options.size(); i++) { diff --git a/compiler/src/test/golden/TestDeprecatedService.java.txt b/compiler/src/test/golden/TestDeprecatedService.java.txt index 04a7f2406b3..1c37c9a8af9 100644 --- a/compiler/src/test/golden/TestDeprecatedService.java.txt +++ b/compiler/src/test/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; *

*/ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.68.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.81.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated @@ -64,6 +64,21 @@ public final class TestDeprecatedServiceGrpc { return TestDeprecatedServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestDeprecatedServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestDeprecatedServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestDeprecatedServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestDeprecatedServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -169,6 +184,38 @@ public final class TestDeprecatedServiceGrpc { *
*/ @java.lang.Deprecated + public static final class TestDeprecatedServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestDeprecatedServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestDeprecatedServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestDeprecatedServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * An RPC method that has been deprecated and should generate with Java's @Deprecated annotation
+     * 
+ */ + @java.lang.Deprecated + public io.grpc.testing.compiler.Test.SimpleResponse deprecatedMethod(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getDeprecatedMethodMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestDeprecatedService. + *
+   * Test service that has been deprecated and should generate with Java's @Deprecated annotation
+   * 
+ */ + @java.lang.Deprecated public static final class TestDeprecatedServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestDeprecatedServiceBlockingStub( diff --git a/compiler/src/test/golden/TestService.java.txt b/compiler/src/test/golden/TestService.java.txt index d69abad7cbb..08eb2fb6ac3 100644 --- a/compiler/src/test/golden/TestService.java.txt +++ b/compiler/src/test/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.68.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.81.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { @@ -282,6 +282,21 @@ public final class TestServiceGrpc { return TestServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -546,6 +561,125 @@ public final class TestServiceGrpc { * Test service that supports all call types. 
* */ + public static final class TestServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * One request followed by one response.
+     * The server returns the client payload as-is.
+     * 
+ */ + public io.grpc.testing.compiler.Test.SimpleResponse unaryCall(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by a sequence of responses (streamed download).
+     * The server returns the payload with client desired type and sizes.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingOutputCall(io.grpc.testing.compiler.Test.StreamingOutputCallRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamingOutputCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A sequence of requests followed by one response (streamed upload).
+     * The server returns the aggregated size of client payload as the result.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingInputCall() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getStreamingInputCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests with each request served by the server immediately.
+     * As one request could lead to multiple responses, this interface
+     * demonstrates the idea of full bidirectionality.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + fullBidiCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getFullBidiCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests followed by a sequence of responses.
+     * The server buffers all the client requests and then serves them in order. A
+     * stream of responses are returned to the client when the server starts with
+     * first request.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + halfBidiCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getHalfBidiCallMethod(), getCallOptions()); + } + + /** + *
+     * An RPC method whose Java name collides with a keyword, and whose generated
+     * method should have a '_' appended.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + import_() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getImportMethod(), getCallOptions()); + } + + /** + *
+     * A unary call that is Safe.
+     * 
+ */ + public io.grpc.testing.compiler.Test.SimpleResponse safeCall(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSafeCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A unary call that is Idempotent.
+     * 
+ */ + public io.grpc.testing.compiler.Test.SimpleResponse idempotentCall(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getIdempotentCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestService. + *
+   * Test service that supports all call types.
+   * 
+ */ public static final class TestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestServiceBlockingStub( diff --git a/compiler/src/testLite/golden/TestDeprecatedService.java.txt b/compiler/src/testLite/golden/TestDeprecatedService.java.txt index 3a7dba9bbb5..89ea2e698bf 100644 --- a/compiler/src/testLite/golden/TestDeprecatedService.java.txt +++ b/compiler/src/testLite/golden/TestDeprecatedService.java.txt @@ -60,6 +60,21 @@ public final class TestDeprecatedServiceGrpc { return TestDeprecatedServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestDeprecatedServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestDeprecatedServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestDeprecatedServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestDeprecatedServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -165,6 +180,38 @@ public final class TestDeprecatedServiceGrpc { * */ @java.lang.Deprecated + public static final class TestDeprecatedServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestDeprecatedServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestDeprecatedServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestDeprecatedServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * An RPC method that has been deprecated and should generate with Java's @Deprecated annotation
+     * 
+ */ + @java.lang.Deprecated + public io.grpc.testing.compiler.Test.SimpleResponse deprecatedMethod(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getDeprecatedMethodMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestDeprecatedService. + *
+   * Test service that has been deprecated and should generate with Java's @Deprecated annotation
+   * 
+ */ + @java.lang.Deprecated public static final class TestDeprecatedServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestDeprecatedServiceBlockingStub( diff --git a/compiler/src/testLite/golden/TestService.java.txt b/compiler/src/testLite/golden/TestService.java.txt index f86fb50d7dc..4e9dfb8d682 100644 --- a/compiler/src/testLite/golden/TestService.java.txt +++ b/compiler/src/testLite/golden/TestService.java.txt @@ -271,6 +271,21 @@ public final class TestServiceGrpc { return TestServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -535,6 +550,125 @@ public final class TestServiceGrpc { * Test service that supports all call types. * */ + public static final class TestServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * One request followed by one response.
+     * The server returns the client payload as-is.
+     * 
+ */ + public io.grpc.testing.compiler.Test.SimpleResponse unaryCall(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by a sequence of responses (streamed download).
+     * The server returns the payload with client desired type and sizes.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingOutputCall(io.grpc.testing.compiler.Test.StreamingOutputCallRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamingOutputCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A sequence of requests followed by one response (streamed upload).
+     * The server returns the aggregated size of client payload as the result.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingInputCall() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getStreamingInputCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests with each request served by the server immediately.
+     * As one request could lead to multiple responses, this interface
+     * demonstrates the idea of full bidirectionality.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + fullBidiCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getFullBidiCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests followed by a sequence of responses.
+     * The server buffers all the client requests and then serves them in order. A
+     * stream of responses are returned to the client when the server starts with
+     * first request.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + halfBidiCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getHalfBidiCallMethod(), getCallOptions()); + } + + /** + *
+     * An RPC method whose Java name collides with a keyword, and whose generated
+     * method should have a '_' appended.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + import_() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getImportMethod(), getCallOptions()); + } + + /** + *
+     * A unary call that is Safe.
+     * 
+ */ + public io.grpc.testing.compiler.Test.SimpleResponse safeCall(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSafeCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A unary call that is Idempotent.
+     * 
+ */ + public io.grpc.testing.compiler.Test.SimpleResponse idempotentCall(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getIdempotentCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestService. + *
+   * Test service that supports all call types.
+   * 
+ */ public static final class TestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestServiceBlockingStub( diff --git a/context/BUILD.bazel b/context/BUILD.bazel index d0c4b04ce00..0a51dca24a9 100644 --- a/context/BUILD.bazel +++ b/context/BUILD.bazel @@ -1,3 +1,5 @@ +load("@rules_java//java:defs.bzl", "java_library") + java_library( name = "context", visibility = ["//visibility:public"], diff --git a/contextstorage/build.gradle b/contextstorage/build.gradle new file mode 100644 index 00000000000..b1e78ea0e17 --- /dev/null +++ b/contextstorage/build.gradle @@ -0,0 +1,35 @@ +plugins { + id "java-library" + id "maven-publish" + + id "ru.vyarus.animalsniffer" +} + +description = 'gRPC: ContextStorageOverride' + +dependencies { + api project(':grpc-api') + implementation libraries.opentelemetry.api + + testImplementation libraries.junit, + libraries.opentelemetry.sdk.testing, + libraries.assertj.core + testImplementation 'junit:junit:4.13.1'// opentelemetry.sdk.testing uses compileOnly for assertj + + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } +} + +tasks.named("jar").configure { + manifest { + attributes('Automatic-Module-Name': 'io.grpc.override') + } +} diff --git a/contextstorage/src/main/java/io/grpc/override/ContextStorageOverride.java b/contextstorage/src/main/java/io/grpc/override/ContextStorageOverride.java new file mode 100644 index 00000000000..41b24765de0 --- /dev/null +++ b/contextstorage/src/main/java/io/grpc/override/ContextStorageOverride.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.override; + +import io.grpc.Context; + +/** + * Including this class in your dependencies will override the default gRPC context storage using + * reflection. It is a bridge between {@link io.grpc.Context} and + * {@link io.opentelemetry.context.Context}, i.e. propagating io.grpc.context.Context also + * propagates io.opentelemetry.context, and propagating io.opentelemetry.context will also propagate + * io.grpc.context. + */ +public final class ContextStorageOverride extends Context.Storage { + + private final Context.Storage delegate = new OpenTelemetryContextStorage(); + + @Override + public Context doAttach(Context toAttach) { + return delegate.doAttach(toAttach); + } + + @Override + public void detach(Context toDetach, Context toRestore) { + delegate.detach(toDetach, toRestore); + } + + @Override + public Context current() { + return delegate.current(); + } +} diff --git a/contextstorage/src/main/java/io/grpc/override/OpenTelemetryContextStorage.java b/contextstorage/src/main/java/io/grpc/override/OpenTelemetryContextStorage.java new file mode 100644 index 00000000000..01356e9f406 --- /dev/null +++ b/contextstorage/src/main/java/io/grpc/override/OpenTelemetryContextStorage.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.override; + +import io.grpc.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A Context.Storage implementation that attaches io.grpc.context to OpenTelemetry's context and + * io.opentelemetry.context is also saved in the io.grpc.context. + * Bridge between {@link io.grpc.Context} and {@link io.opentelemetry.context.Context}. + */ +final class OpenTelemetryContextStorage extends Context.Storage { + private static final Logger logger = Logger.getLogger( + OpenTelemetryContextStorage.class.getName()); + + private static final io.grpc.Context.Key OTEL_CONTEXT_OVER_GRPC + = io.grpc.Context.key("otel-context-over-grpc"); + private static final Context.Key OTEL_SCOPE = Context.key("otel-scope"); + private static final ContextKey GRPC_CONTEXT_OVER_OTEL = + ContextKey.named("grpc-context-over-otel"); + + @Override + @SuppressWarnings("MustBeClosedChecker") + public Context doAttach(Context toAttach) { + io.grpc.Context previous = current(); + io.opentelemetry.context.Context otelContext = OTEL_CONTEXT_OVER_GRPC.get(toAttach); + if (otelContext == null) { + otelContext = io.opentelemetry.context.Context.current(); + } + Scope scope = otelContext.with(GRPC_CONTEXT_OVER_OTEL, toAttach).makeCurrent(); + return previous.withValue(OTEL_SCOPE, scope); + } + + @Override + public void detach(Context toDetach, Context toRestore) { + Scope scope = OTEL_SCOPE.get(toRestore); + if (scope == null) { + logger.log( + 
Level.SEVERE, "Detaching context which was not attached."); + } else { + scope.close(); + } + } + + @Override + public Context current() { + io.opentelemetry.context.Context otelCurrent = io.opentelemetry.context.Context.current(); + io.grpc.Context grpcCurrent = otelCurrent.get(GRPC_CONTEXT_OVER_OTEL); + if (grpcCurrent == null) { + grpcCurrent = Context.ROOT; + } + return grpcCurrent.withValue(OTEL_CONTEXT_OVER_GRPC, otelCurrent); + } +} diff --git a/contextstorage/src/test/java/io/grpc/override/OpenTelemetryContextStorageTest.java b/contextstorage/src/test/java/io/grpc/override/OpenTelemetryContextStorageTest.java new file mode 100644 index 00000000000..3c628964342 --- /dev/null +++ b/contextstorage/src/test/java/io/grpc/override/OpenTelemetryContextStorageTest.java @@ -0,0 +1,144 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.override; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import com.google.common.util.concurrent.SettableFuture; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class OpenTelemetryContextStorageTest { + @Rule + public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); + private Tracer tracerRule = openTelemetryRule.getOpenTelemetry().getTracer( + "context-storage-test"); + private final io.grpc.Context.Key<String> username = io.grpc.Context.key("username"); + private final ContextKey<String> password = ContextKey.named("password"); + + @Test + public void grpcContextPropagation() throws Exception { + final Span parentSpan = tracerRule.spanBuilder("test-context").startSpan(); + final SettableFuture<Span> spanPropagated = SettableFuture.create(); + final SettableFuture<String> grpcContextPropagated = SettableFuture.create(); + final SettableFuture<Span> spanDetached = SettableFuture.create(); + final SettableFuture<String> grpcContextDetached = SettableFuture.create(); + + io.grpc.Context grpcContext; + try (Scope scope = Context.current().with(parentSpan).makeCurrent()) { + grpcContext = io.grpc.Context.current().withValue(username, "jeff"); + } + new Thread(new Runnable() { + @Override + public void run() { + io.grpc.Context previous = grpcContext.attach(); + try { + grpcContextPropagated.set(username.get(io.grpc.Context.current())); + spanPropagated.set(Span.fromContext(io.opentelemetry.context.Context.current())); + } finally { + 
grpcContext.detach(previous); + spanDetached.set(Span.fromContext(io.opentelemetry.context.Context.current())); + grpcContextDetached.set(username.get(io.grpc.Context.current())); + } + } + }).start(); + Assert.assertEquals(spanPropagated.get(5, TimeUnit.SECONDS), parentSpan); + Assert.assertEquals(grpcContextPropagated.get(5, TimeUnit.SECONDS), "jeff"); + Assert.assertEquals(spanDetached.get(5, TimeUnit.SECONDS), Span.getInvalid()); + Assert.assertNull(grpcContextDetached.get(5, TimeUnit.SECONDS)); + } + + @Test + public void otelContextPropagation() throws Exception { + final SettableFuture<String> grpcPropagated = SettableFuture.create(); + final AtomicReference<String> otelPropagation = new AtomicReference<>(); + + io.grpc.Context grpcContext = io.grpc.Context.current().withValue(username, "jeff"); + io.grpc.Context previous = grpcContext.attach(); + Context original = Context.current().with(password, "valentine"); + try { + new Thread( + () -> { + try (Scope scope = original.makeCurrent()) { + otelPropagation.set(Context.current().get(password)); + grpcPropagated.set(username.get(io.grpc.Context.current())); + } + } + ).start(); + } finally { + grpcContext.detach(previous); + } + Assert.assertEquals(grpcPropagated.get(5, TimeUnit.SECONDS), "jeff"); + Assert.assertEquals(otelPropagation.get(), "valentine"); + } + + @Test + public void grpcOtelMix() { + io.grpc.Context grpcContext = io.grpc.Context.current().withValue(username, "jeff"); + Context otelContext = Context.current().with(password, "valentine"); + Assert.assertNull(username.get(io.grpc.Context.current())); + Assert.assertNull(Context.current().get(password)); + io.grpc.Context previous = grpcContext.attach(); + try { + assertEquals(username.get(io.grpc.Context.current()), "jeff"); + try (Scope scope = otelContext.makeCurrent()) { + Assert.assertEquals(Context.current().get(password), "valentine"); + assertNull(username.get(io.grpc.Context.current())); + + io.grpc.Context grpcContext2 = 
io.grpc.Context.current().withValue(username, "frank"); + io.grpc.Context previous2 = grpcContext2.attach(); + try { + assertEquals(username.get(io.grpc.Context.current()), "frank"); + Assert.assertEquals(Context.current().get(password), "valentine"); + } finally { + grpcContext2.detach(previous2); + } + assertNull(username.get(io.grpc.Context.current())); + Assert.assertEquals(Context.current().get(password), "valentine"); + } + } finally { + grpcContext.detach(previous); + } + Assert.assertNull(username.get(io.grpc.Context.current())); + Assert.assertNull(Context.current().get(password)); + } + + @Test + public void grpcContextDetachError() { + io.grpc.Context grpcContext = io.grpc.Context.current().withValue(username, "jeff"); + io.grpc.Context previous = grpcContext.attach(); + try { + previous.detach(grpcContext); + assertEquals(username.get(io.grpc.Context.current()), "jeff"); + } finally { + grpcContext.detach(previous); + } + } +} diff --git a/core/BUILD.bazel b/core/BUILD.bazel index 35c20628d0b..1a743ff9eda 100644 --- a/core/BUILD.bazel +++ b/core/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( @@ -17,7 +18,6 @@ java_library( srcs = glob([ "src/main/java/io/grpc/internal/*.java", ]), - javacopts = ["-Xep:DoNotCall:OFF"], # Remove once requiring Bazel 3.4.0+; allows non-final resources = glob([ "src/bazel-internal/resources/**", ]), diff --git a/core/build.gradle b/core/build.gradle index f8a95c37286..b320f326b41 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -1,6 +1,6 @@ buildscript { dependencies { - classpath 'com.google.guava:guava:30.0-android' + classpath 'com.google.guava:guava:33.4.8-android' } } @@ -39,8 +39,16 @@ dependencies { jmh project(':grpc-testing') - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature 
(libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("javadoc").configure { diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index bb346657d53..bce1820b482 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -21,7 +21,6 @@ import static io.grpc.internal.GrpcUtil.CONTENT_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; -import static java.lang.Math.max; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; @@ -44,9 +43,9 @@ import javax.annotation.Nullable; /** - * The abstract base class for {@link ClientStream} implementations. Extending classes only need to - * implement {@link #transportState()} and {@link #abstractClientStreamSink()}. Must only be called - * from the sending application thread. + * The abstract base class for {@link ClientStream} implementations. + * + *

Must only be called from the sending application thread. */ public abstract class AbstractClientStream extends AbstractStream implements ClientStream, MessageFramer.Sink { @@ -102,6 +101,7 @@ void writeFrame( */ private volatile boolean cancelled; + @SuppressWarnings("this-escape") protected AbstractClientStream( WritableBufferAllocator bufferAllocator, StatsTraceContext statsTraceCtx, @@ -114,7 +114,7 @@ protected AbstractClientStream( this.shouldBeCountedForInUse = GrpcUtil.shouldBeCountedForInUse(callOptions); this.useGet = useGet; if (!useGet) { - framer = new MessageFramer(this, bufferAllocator, statsTraceCtx); + this.framer = new MessageFramer(this, bufferAllocator, statsTraceCtx); this.headers = headers; } else { framer = new GetFramer(headers, statsTraceCtx); @@ -124,8 +124,7 @@ protected AbstractClientStream( @Override public void setDeadline(Deadline deadline) { headers.discardAll(TIMEOUT_KEY); - long effectiveTimeout = max(0, deadline.timeRemaining(TimeUnit.NANOSECONDS)); - headers.put(TIMEOUT_KEY, effectiveTimeout); + headers.put(TIMEOUT_KEY, deadline.timeRemaining(TimeUnit.NANOSECONDS)); } @Override diff --git a/core/src/main/java/io/grpc/internal/AbstractServerStream.java b/core/src/main/java/io/grpc/internal/AbstractServerStream.java index a535330f4b1..c468cba978a 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerStream.java @@ -75,10 +75,11 @@ protected interface Sink { private boolean outboundClosed; private boolean headersSent; + @SuppressWarnings("this-escape") protected AbstractServerStream( WritableBufferAllocator bufferAllocator, StatsTraceContext statsTraceCtx) { this.statsTraceCtx = Preconditions.checkNotNull(statsTraceCtx, "statsTraceCtx"); - framer = new MessageFramer(this, bufferAllocator, statsTraceCtx); + this.framer = new MessageFramer(this, bufferAllocator, statsTraceCtx); } @Override diff --git a/core/src/main/java/io/grpc/internal/AbstractStream.java 
b/core/src/main/java/io/grpc/internal/AbstractStream.java index 9efc488657b..9f5fb035dab 100644 --- a/core/src/main/java/io/grpc/internal/AbstractStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractStream.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Codec; import io.grpc.Compressor; import io.grpc.Decompressor; @@ -27,13 +28,16 @@ import io.perfmark.PerfMark; import io.perfmark.TaskCloseable; import java.io.InputStream; -import javax.annotation.concurrent.GuardedBy; +import java.util.logging.Level; +import java.util.logging.Logger; /** * The stream and stream state as used by the application. Must only be called from the sending * application thread. */ public abstract class AbstractStream implements Stream { + private static final Logger log = Logger.getLogger(AbstractStream.class.getName()); + /** The framer to use for sending messages. */ protected abstract Framer framer(); @@ -159,20 +163,21 @@ public abstract static class TransportState @GuardedBy("onReadyLock") private int onReadyThreshold; + @SuppressWarnings("this-escape") protected TransportState( int maxMessageSize, StatsTraceContext statsTraceCtx, TransportTracer transportTracer) { this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx"); this.transportTracer = checkNotNull(transportTracer, "transportTracer"); - rawDeframer = new MessageDeframer( + this.rawDeframer = new MessageDeframer( this, Codec.Identity.NONE, maxMessageSize, statsTraceCtx, transportTracer); // TODO(#7168): use MigratingThreadDeframer when enabling retry doesn't break. 
- deframer = rawDeframer; + deframer = this.rawDeframer; onReadyThreshold = DEFAULT_ONREADY_THRESHOLD; } @@ -371,6 +376,12 @@ private void notifyIfReady() { boolean doNotify; synchronized (onReadyLock) { doNotify = isReady(); + if (!doNotify && log.isLoggable(Level.FINEST)) { + log.log(Level.FINEST, + "Stream not ready so skip notifying listener.\n" + + "details: allocated/deallocated:{0}/{3}, sent queued: {1}, ready thresh: {2}", + new Object[] {allocated, numSentBytesQueued, onReadyThreshold, deallocated}); + } } if (doNotify) { listener().onReady(); diff --git a/core/src/main/java/io/grpc/internal/AuthorityVerifier.java b/core/src/main/java/io/grpc/internal/AuthorityVerifier.java new file mode 100644 index 00000000000..e6164a7dc4d --- /dev/null +++ b/core/src/main/java/io/grpc/internal/AuthorityVerifier.java @@ -0,0 +1,24 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import io.grpc.Status; + +/** Verifier for the outgoing authority pseudo-header against peer cert. 
*/ +public interface AuthorityVerifier { + Status verifyAuthority(String authority); +} diff --git a/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java b/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java index a382227fd6c..dcefa8f8351 100644 --- a/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java +++ b/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java @@ -19,17 +19,15 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.MoreObjects; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.FixedResultPicker; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; -import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; -import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; @@ -40,12 +38,10 @@ import java.util.Map; import javax.annotation.Nullable; -// TODO(creamsoup) fully deprecate LoadBalancer.ATTR_LOAD_BALANCING_CONFIG -@SuppressWarnings("deprecation") -public final class AutoConfiguredLoadBalancerFactory { +public final class AutoConfiguredLoadBalancerFactory extends LoadBalancerProvider { private final LoadBalancerRegistry registry; - private final String defaultPolicy; + private final LoadBalancerProvider defaultProvider; public AutoConfiguredLoadBalancerFactory(String defaultPolicy) { this(LoadBalancerRegistry.getDefaultRegistry(), defaultPolicy); @@ -54,47 +50,34 @@ public AutoConfiguredLoadBalancerFactory(String defaultPolicy) { @VisibleForTesting AutoConfiguredLoadBalancerFactory(LoadBalancerRegistry registry, String defaultPolicy) { 
this.registry = checkNotNull(registry, "registry"); - this.defaultPolicy = checkNotNull(defaultPolicy, "defaultPolicy"); + LoadBalancerProvider provider = + registry.getProvider(checkNotNull(defaultPolicy, "defaultPolicy")); + if (provider == null) { + Status status = Status.INTERNAL.withDescription("Could not find policy '" + defaultPolicy + + "'. Make sure its implementation is either registered to LoadBalancerRegistry or" + + " included in META-INF/services/io.grpc.LoadBalancerProvider from your jar files."); + provider = new FixedPickerLoadBalancerProvider( + ConnectivityState.TRANSIENT_FAILURE, + new LoadBalancer.FixedResultPicker(PickResult.withError(status)), + status); + } + this.defaultProvider = provider; } + @Override public AutoConfiguredLoadBalancer newLoadBalancer(Helper helper) { return new AutoConfiguredLoadBalancer(helper); } - private static final class NoopLoadBalancer extends LoadBalancer { - - @Override - @Deprecated - @SuppressWarnings("InlineMeSuggester") - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { - } - - @Override - public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - return Status.OK; - } - - @Override - public void handleNameResolutionError(Status error) {} - - @Override - public void shutdown() {} - } - @VisibleForTesting - public final class AutoConfiguredLoadBalancer { + public final class AutoConfiguredLoadBalancer extends LoadBalancer { private final Helper helper; private LoadBalancer delegate; private LoadBalancerProvider delegateProvider; AutoConfiguredLoadBalancer(Helper helper) { this.helper = helper; - delegateProvider = registry.getProvider(defaultPolicy); - if (delegateProvider == null) { - throw new IllegalStateException("Could not find policy '" + defaultPolicy - + "'. 
Make sure its implementation is either registered to LoadBalancerRegistry or" - + " included in META-INF/services/io.grpc.LoadBalancerProvider from your jar files."); - } + this.delegateProvider = defaultProvider; delegate = delegateProvider.newLoadBalancer(helper); } @@ -102,29 +85,20 @@ public final class AutoConfiguredLoadBalancer { * Returns non-OK status if the delegate rejects the resolvedAddresses (e.g. if it does not * support an empty list). */ - Status tryAcceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { PolicySelection policySelection = (PolicySelection) resolvedAddresses.getLoadBalancingPolicyConfig(); if (policySelection == null) { - LoadBalancerProvider defaultProvider; - try { - defaultProvider = getProviderOrThrow(defaultPolicy, "using default policy"); - } catch (PolicyException e) { - Status s = Status.INTERNAL.withDescription(e.getMessage()); - helper.updateBalancingState(ConnectivityState.TRANSIENT_FAILURE, new FailingPicker(s)); - delegate.shutdown(); - delegateProvider = null; - delegate = new NoopLoadBalancer(); - return Status.OK; - } policySelection = new PolicySelection(defaultProvider, /* config= */ null); } if (delegateProvider == null || !policySelection.provider.getPolicyName().equals(delegateProvider.getPolicyName())) { - helper.updateBalancingState(ConnectivityState.CONNECTING, new EmptyPicker()); + helper.updateBalancingState( + ConnectivityState.CONNECTING, new FixedResultPicker(PickResult.withNoResult())); delegate.shutdown(); delegateProvider = policySelection.provider; LoadBalancer old = delegate; @@ -147,20 +121,24 @@ Status tryAcceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { .build()); } - void handleNameResolutionError(Status error) { + @Override + public void handleNameResolutionError(Status error) { getDelegate().handleNameResolutionError(error); } + @Override @Deprecated - void 
handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { + public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { getDelegate().handleSubchannelState(subchannel, stateInfo); } - void requestConnection() { + @Override + public void requestConnection() { getDelegate().requestConnection(); } - void shutdown() { + @Override + public void shutdown() { delegate.shutdown(); delegate = null; } @@ -181,16 +159,6 @@ LoadBalancerProvider getDelegateProvider() { } } - private LoadBalancerProvider getProviderOrThrow(String policy, String choiceReason) - throws PolicyException { - LoadBalancerProvider provider = registry.getProvider(policy); - if (provider == null) { - throw new PolicyException( - "Trying to load '" + policy + "' because " + choiceReason + ", but it's unavailable"); - } - return provider; - } - /** * Parses first available LoadBalancer policy from service config. Available LoadBalancer should * be registered to {@link LoadBalancerRegistry}. If the first available LoadBalancer policy is @@ -211,8 +179,11 @@ private LoadBalancerProvider getProviderOrThrow(String policy, String choiceReas * * @return the parsed {@link PolicySelection}, or {@code null} if no selection could be made. */ + // TODO(ejona): The Provider API doesn't allow null, but ScParser can handle this and it will need + // tweaking to ManagedChannelImpl.defaultServiceConfig to fix. 
@Nullable - ConfigOrError parseLoadBalancerPolicy(Map serviceConfig) { + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map serviceConfig) { try { List loadBalancerConfigs = null; if (serviceConfig != null) { @@ -230,38 +201,18 @@ ConfigOrError parseLoadBalancerPolicy(Map serviceConfig) { } } - @VisibleForTesting - static final class PolicyException extends Exception { - private static final long serialVersionUID = 1L; - - private PolicyException(String msg) { - super(msg); - } + @Override + public boolean isAvailable() { + return true; } - private static final class EmptyPicker extends SubchannelPicker { - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withNoResult(); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(EmptyPicker.class).toString(); - } + @Override + public int getPriority() { + return 5; } - private static final class FailingPicker extends SubchannelPicker { - private final Status failure; - - FailingPicker(Status failure) { - this.failure = failure; - } - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withError(failure); - } + @Override + public String getPolicyName() { + return "auto_configured_internal"; } } diff --git a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java index 42631851974..97a74bda97e 100644 --- a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java @@ -19,6 +19,7 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkNotNull; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallCredentials; import io.grpc.CallCredentials.RequestInfo; @@ 
-38,7 +39,6 @@ import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicInteger; -import javax.annotation.concurrent.GuardedBy; final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory { private final ClientTransportFactory delegate; diff --git a/core/src/main/java/io/grpc/internal/CertificateUtils.java b/core/src/main/java/io/grpc/internal/CertificateUtils.java new file mode 100644 index 00000000000..130a435bb1a --- /dev/null +++ b/core/src/main/java/io/grpc/internal/CertificateUtils.java @@ -0,0 +1,106 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.List; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; +import javax.security.auth.x500.X500Principal; + +/** + * Contains certificate/key PEM file utility method(s) for internal usage. 
+ */ +public final class CertificateUtils { + private static final Class<?> x509ExtendedTrustManagerClass; + + static { + Class<?> x509ExtendedTrustManagerClass1; + try { + x509ExtendedTrustManagerClass1 = Class.forName("javax.net.ssl.X509ExtendedTrustManager"); + } catch (ClassNotFoundException e) { + x509ExtendedTrustManagerClass1 = null; + // Will disallow per-rpc authority override via call option. + } + x509ExtendedTrustManagerClass = x509ExtendedTrustManagerClass1; + } + + /** + * Creates X509TrustManagers using the provided CA certs. + */ + public static TrustManager[] createTrustManager(byte[] rootCerts) + throws GeneralSecurityException { + InputStream rootCertsStream = new ByteArrayInputStream(rootCerts); + try { + return CertificateUtils.createTrustManager(rootCertsStream); + } finally { + GrpcUtil.closeQuietly(rootCertsStream); + } + } + + /** + * Creates X509TrustManagers using the provided input stream of CA certs. + */ + public static TrustManager[] createTrustManager(InputStream rootCerts) + throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. 
+ throw new GeneralSecurityException(ex); + } + X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts); + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + return trustManagerFactory.getTrustManagers(); + } + + public static X509TrustManager getX509ExtendedTrustManager(List<TrustManager> trustManagers) { + if (x509ExtendedTrustManagerClass != null) { + for (TrustManager trustManager : trustManagers) { + if (x509ExtendedTrustManagerClass.isInstance(trustManager)) { + return (X509TrustManager) trustManager; + } + } + } + return null; + } + + private static X509Certificate[] getX509Certificates(InputStream inputStream) + throws CertificateException { + CertificateFactory factory = CertificateFactory.getInstance("X.509"); + Collection<? extends Certificate> certs = factory.generateCertificates(inputStream); + return certs.toArray(new X509Certificate[0]); + } +} diff --git a/core/src/main/java/io/grpc/internal/ChannelTracer.java b/core/src/main/java/io/grpc/internal/ChannelTracer.java index 8c8243c9021..a9730a365cc 100644 --- a/core/src/main/java/io/grpc/internal/ChannelTracer.java +++ b/core/src/main/java/io/grpc/internal/ChannelTracer.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.ChannelLogger; import io.grpc.InternalChannelz.ChannelStats; import io.grpc.InternalChannelz.ChannelTrace; @@ -31,7 +32,6 @@ import java.util.logging.LogRecord; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Tracks a collections of channel tracing events for a channel/subchannel. 
diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index 07f2701d1c1..4b24b1eae3d 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -250,7 +250,8 @@ public void runInContext() { stream = clientStreamProvider.newStream(method, callOptions, headers, context); } else { ClientStreamTracer[] tracers = - GrpcUtil.getClientStreamTracers(callOptions, headers, 0, false); + GrpcUtil.getClientStreamTracers(callOptions, headers, 0, + false, false); String deadlineName = contextIsDeadlineSource ? "Context" : "CallOptions"; Long nameResolutionDelay = callOptions.getOption(NAME_RESOLUTION_DELAYED); String description = String.format( @@ -561,7 +562,11 @@ public Attributes getAttributes() { } private void closeObserver(Listener observer, Status status, Metadata trailers) { - observer.onClose(status, trailers); + try { + observer.onClose(status, trailers); + } catch (RuntimeException ex) { + log.log(Level.WARNING, "Exception thrown by onClose() in ClientCall", ex); + } } @Override diff --git a/core/src/main/java/io/grpc/internal/ClientTransport.java b/core/src/main/java/io/grpc/internal/ClientTransport.java index 98041cc6e79..fd0f30b8bf1 100644 --- a/core/src/main/java/io/grpc/internal/ClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ClientTransport.java @@ -22,6 +22,7 @@ import io.grpc.InternalInstrumented; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.Status; import java.util.concurrent.Executor; import javax.annotation.concurrent.ThreadSafe; @@ -90,6 +91,6 @@ interface PingCallback { * * @param cause the cause of the ping failure */ - void onFailure(Throwable cause); + void onFailure(Status cause); } } diff --git a/core/src/main/java/io/grpc/internal/ClientTransportFactory.java b/core/src/main/java/io/grpc/internal/ClientTransportFactory.java index d987f9d5068..6023fb14aa9 100644 
--- a/core/src/main/java/io/grpc/internal/ClientTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/ClientTransportFactory.java @@ -18,16 +18,17 @@ import com.google.common.base.Objects; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.CallCredentials; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; import io.grpc.HttpConnectProxiedSocketAddress; +import io.grpc.MetricRecorder; import java.io.Closeable; import java.net.SocketAddress; import java.util.Collection; import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; /** Pre-configured factory for creating {@link ConnectionClientTransport} instances. */ @@ -91,6 +92,8 @@ final class ClientTransportOptions { private Attributes eagAttributes = Attributes.EMPTY; @Nullable private String userAgent; @Nullable private HttpConnectProxiedSocketAddress connectProxiedSocketAddr; + private MetricRecorder metricRecorder = new MetricRecorder() { + }; public ChannelLogger getChannelLogger() { return channelLogger; @@ -101,6 +104,15 @@ public ClientTransportOptions setChannelLogger(ChannelLogger channelLogger) { return this; } + public MetricRecorder getMetricRecorder() { + return metricRecorder; + } + + public ClientTransportOptions setMetricRecorder(MetricRecorder metricRecorder) { + this.metricRecorder = Preconditions.checkNotNull(metricRecorder, "metricRecorder"); + return this; + } + public String getAuthority() { return authority; } diff --git a/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java b/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java index 4407eb8a2a2..6cedb2caee9 100644 --- a/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java @@ -18,12 +18,10 @@ import java.io.IOException; import 
java.io.OutputStream; -import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.InvalidMarkException; import java.util.ArrayDeque; import java.util.Deque; -import java.util.Queue; import javax.annotation.Nullable; /** @@ -39,7 +37,6 @@ public class CompositeReadableBuffer extends AbstractReadableBuffer { private final Deque readableBuffers; private Deque rewindableBuffers; private int readableBytes; - private final Queue buffers = new ArrayDeque(2); private boolean marked; public CompositeReadableBuffer(int initialCapacity) { @@ -122,20 +119,6 @@ public int read(ReadableBuffer buffer, int length, byte[] dest, int offset) { } }; - private static final NoThrowReadOperation BYTE_BUF_OP = - new NoThrowReadOperation() { - @Override - public int read(ReadableBuffer buffer, int length, ByteBuffer dest, int unused) { - // Change the limit so that only lengthToCopy bytes are available. - int prevLimit = dest.limit(); - ((Buffer) dest).limit(dest.position() + length); - // Write the bytes and restore the original limit. - buffer.readBytes(dest); - ((Buffer) dest).limit(prevLimit); - return 0; - } - }; - private static final ReadOperation STREAM_OP = new ReadOperation() { @Override @@ -151,41 +134,11 @@ public void readBytes(byte[] dest, int destOffset, int length) { executeNoThrow(BYTE_ARRAY_OP, length, dest, destOffset); } - @Override - public void readBytes(ByteBuffer dest) { - executeNoThrow(BYTE_BUF_OP, dest.remaining(), dest, 0); - } - @Override public void readBytes(OutputStream dest, int length) throws IOException { execute(STREAM_OP, length, dest, 0); } - /** - * Reads {@code length} bytes from this buffer and writes them to the destination buffer. - * Increments the read position by {@code length}. If the required bytes are not readable, throws - * {@link IndexOutOfBoundsException}. - * - * @param dest the destination buffer to receive the bytes. - * @param length the number of bytes to be copied. 
- * @throws IndexOutOfBoundsException if required bytes are not readable - */ - public void readBytes(CompositeReadableBuffer dest, int length) { - checkReadable(length); - readableBytes -= length; - - while (length > 0) { - ReadableBuffer buffer = buffers.peek(); - if (buffer.readableBytes() > length) { - dest.addBuffer(buffer.readBytes(length)); - length = 0; - } else { - dest.addBuffer(buffers.poll()); - length -= buffer.readableBytes(); - } - } - } - @Override public ReadableBuffer readBytes(int length) { if (length <= 0) { diff --git a/core/src/main/java/io/grpc/internal/ConcurrentTimeProvider.java b/core/src/main/java/io/grpc/internal/ConcurrentTimeProvider.java new file mode 100644 index 00000000000..c82a68222b4 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/ConcurrentTimeProvider.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import java.util.concurrent.TimeUnit; + +/** + * {@link ConcurrentTimeProvider} resolves ConcurrentTimeProvider which implements + * {@link TimeProvider}. 
+ */ + +final class ConcurrentTimeProvider implements TimeProvider { + + @Override + public long currentTimeNanos() { + return TimeUnit.MILLISECONDS.toNanos(System.currentTimeMillis()); + } +} diff --git a/core/src/main/java/io/grpc/internal/DelayedClientCall.java b/core/src/main/java/io/grpc/internal/DelayedClientCall.java index 92034e83f45..b568bb12c46 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientCall.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientCall.java @@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.ClientCall; import io.grpc.Context; @@ -38,7 +39,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * A call that queues requests before a real call is ready to be delegated to. @@ -64,6 +64,8 @@ public class DelayedClientCall extends ClientCall { * order, but also used if an error occurs before {@code realCall} is set. */ private Listener listener; + // No need to synchronize; start() synchronization provides a happens-before + private Metadata startHeaders; // Must hold {@code this} lock when setting. 
private ClientCall realCall; @GuardedBy("this") @@ -96,15 +98,13 @@ private boolean isAbeforeB(@Nullable Deadline a, @Nullable Deadline b) { private ScheduledFuture scheduleDeadlineIfNeeded( ScheduledExecutorService scheduler, @Nullable Deadline deadline) { Deadline contextDeadline = context.getDeadline(); - if (deadline == null && contextDeadline == null) { - return null; - } - long remainingNanos = Long.MAX_VALUE; - if (deadline != null) { + String deadlineName; + long remainingNanos; + if (deadline != null && isAbeforeB(deadline, contextDeadline)) { + deadlineName = "CallOptions"; remainingNanos = deadline.timeRemaining(NANOSECONDS); - } - - if (contextDeadline != null && contextDeadline.timeRemaining(NANOSECONDS) < remainingNanos) { + } else if (contextDeadline != null) { + deadlineName = "Context"; remainingNanos = contextDeadline.timeRemaining(NANOSECONDS); if (logger.isLoggable(Level.FINE)) { StringBuilder builder = @@ -121,29 +121,29 @@ private ScheduledFuture scheduleDeadlineIfNeeded( } logger.fine(builder.toString()); } - } - - long seconds = Math.abs(remainingNanos) / TimeUnit.SECONDS.toNanos(1); - long nanos = Math.abs(remainingNanos) % TimeUnit.SECONDS.toNanos(1); - final StringBuilder buf = new StringBuilder(); - String deadlineName = isAbeforeB(contextDeadline, deadline) ? "Context" : "CallOptions"; - if (remainingNanos < 0) { - buf.append("ClientCall started after "); - buf.append(deadlineName); - buf.append(" deadline was exceeded. Deadline has been exceeded for "); } else { - buf.append("Deadline "); - buf.append(deadlineName); - buf.append(" will be exceeded in "); + return null; } - buf.append(seconds); - buf.append(String.format(Locale.US, ".%09d", nanos)); - buf.append("s. "); /* Cancels the call if deadline exceeded prior to the real call being set. 
*/ class DeadlineExceededRunnable implements Runnable { @Override public void run() { + long seconds = Math.abs(remainingNanos) / TimeUnit.SECONDS.toNanos(1); + long nanos = Math.abs(remainingNanos) % TimeUnit.SECONDS.toNanos(1); + StringBuilder buf = new StringBuilder(); + if (remainingNanos < 0) { + buf.append("ClientCall started after "); + buf.append(deadlineName); + buf.append(" deadline was exceeded. Deadline has been exceeded for "); + } else { + buf.append("Deadline "); + buf.append(deadlineName); + buf.append(" was exceeded after "); + } + buf.append(seconds); + buf.append(String.format(Locale.US, ".%09d", nanos)); + buf.append("s"); cancel( Status.DEADLINE_EXCEEDED.withDescription(buf.toString()), // We should not cancel the call if the realCall is set because there could be a @@ -163,13 +163,23 @@ public void run() { */ // When this method returns, passThrough is guaranteed to be true public final Runnable setCall(ClientCall call) { + Listener savedDelayedListener; synchronized (this) { // If realCall != null, then either setCall() or cancel() has been called. 
if (realCall != null) { return null; } setRealCall(checkNotNull(call, "call")); + // start() not yet called + if (delayedListener == null) { + assert pendingRunnables.isEmpty(); + pendingRunnables = null; + passThrough = true; + return null; + } + savedDelayedListener = this.delayedListener; } + internalStart(savedDelayedListener); return new ContextRunnable(context) { @Override public void runInContext() { @@ -178,8 +188,15 @@ public void runInContext() { }; } + private void internalStart(Listener listener) { + Metadata savedStartHeaders = this.startHeaders; + this.startHeaders = null; + context.run(() -> realCall.start(listener, savedStartHeaders)); + } + @Override public final void start(Listener listener, final Metadata headers) { + checkNotNull(headers, "headers"); checkState(this.listener == null, "already started"); Status savedError; boolean savedPassThrough; @@ -190,6 +207,7 @@ public final void start(Listener listener, final Metadata headers) { savedPassThrough = passThrough; if (!savedPassThrough) { listener = delayedListener = new DelayedListener<>(listener); + startHeaders = headers; } } if (savedError != null) { @@ -198,15 +216,7 @@ public final void start(Listener listener, final Metadata headers) { } if (savedPassThrough) { realCall.start(listener, headers); - } else { - final Listener finalListener = listener; - delayOrExecute(new Runnable() { - @Override - public void run() { - realCall.start(finalListener, headers); - } - }); - } + } // else realCall.start() will be called by setCall } // When this method returns, passThrough is guaranteed to be true @@ -255,6 +265,7 @@ public void run() { if (listenerToClose != null) { callExecutor.execute(new CloseListenerRunnable(listenerToClose, status)); } + internalStart(listenerToClose); // listener instance doesn't matter drainPendingCalls(); } callCancelled(); diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java 
index 6eebfdd0fae..5569e1eecf8 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -19,6 +19,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; import io.grpc.Context; @@ -39,7 +40,6 @@ import java.util.concurrent.Executor; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * A client transport that queues requests before a real transport is available. When {@link @@ -129,14 +129,27 @@ public final ClientStream newStream( if (state.shutdownStatus != null) { return new FailingClientStream(state.shutdownStatus, tracers); } + PickResult pickResult = null; if (state.lastPicker != null) { - PickResult pickResult = state.lastPicker.pickSubchannel(args); + pickResult = state.lastPicker.pickSubchannel(args); + callOptions = args.getCallOptions(); + // User code provided authority takes precedence over the LB provided one. 
+ if (callOptions.getAuthority() == null + && pickResult.getAuthorityOverride() != null) { + callOptions = callOptions.withAuthority(pickResult.getAuthorityOverride()); + } ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult, callOptions.isWaitForReady()); if (transport != null) { - return transport.newStream( - args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions(), + ClientStream stream = transport.newStream( + args.getMethodDescriptor(), args.getHeaders(), callOptions, tracers); + // User code provided authority takes precedence over the LB provided one; this will be + // overwritten by ClientCallImpl if the application sets an authority override + if (pickResult.getAuthorityOverride() != null) { + stream.setAuthority(pickResult.getAuthorityOverride()); + } + return stream; } } // This picker's conclusion is "buffer". If there hasn't been a newer picker set (possible @@ -144,7 +157,7 @@ public final ClientStream newStream( synchronized (lock) { PickerState newerState = pickerState; if (state == newerState) { - return createPendingStream(args, tracers); + return createPendingStream(args, tracers, pickResult); } state = newerState; } @@ -159,9 +172,12 @@ public final ClientStream newStream( * schedule tasks on syncContext. */ @GuardedBy("lock") - private PendingStream createPendingStream( - PickSubchannelArgs args, ClientStreamTracer[] tracers) { + private PendingStream createPendingStream(PickSubchannelArgs args, ClientStreamTracer[] tracers, + PickResult pickResult) { PendingStream pendingStream = new PendingStream(args, tracers); + if (args.getCallOptions().isWaitForReady() && pickResult != null && pickResult.hasResult()) { + pendingStream.lastPickStatus = pickResult.getStatus(); + } pendingStreams.add(pendingStream); if (getPendingStreamsCount() == 1) { syncContext.executeLater(reportTransportInUse); @@ -185,8 +201,8 @@ public ListenableFuture getStats() { } /** - * Prevents creating any new streams. 
Buffered streams are not failed and may still proceed - * when {@link #reprocess} is called. The delayed transport will be terminated when there is no + * Prevents creating any new streams. Buffered streams are not failed and may still proceed + * when {@link #reprocess} is called. The delayed transport will be terminated when there is no * more buffered streams. */ @Override @@ -199,7 +215,7 @@ public final void shutdown(final Status status) { syncContext.executeLater(new Runnable() { @Override public void run() { - listener.transportShutdown(status); + listener.transportShutdown(status, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); } }); if (!hasPendingStreams() && reportTransportTerminated != null) { @@ -281,6 +297,9 @@ final void reprocess(@Nullable SubchannelPicker picker) { for (final PendingStream stream : toProcess) { PickResult pickResult = picker.pickSubchannel(stream.args); CallOptions callOptions = stream.args.getCallOptions(); + if (callOptions.isWaitForReady() && pickResult.hasResult()) { + stream.lastPickStatus = pickResult.getStatus(); + } final ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult, callOptions.isWaitForReady()); if (transport != null) { @@ -291,7 +310,7 @@ final void reprocess(@Nullable SubchannelPicker picker) { if (callOptions.getExecutor() != null) { executor = callOptions.getExecutor(); } - Runnable runnable = stream.createRealStream(transport); + Runnable runnable = stream.createRealStream(transport, pickResult.getAuthorityOverride()); if (runnable != null) { executor.execute(runnable); } @@ -306,7 +325,11 @@ final void reprocess(@Nullable SubchannelPicker picker) { if (!hasPendingStreams()) { return; } - pendingStreams.removeAll(toRemove); + // Avoid pendingStreams.removeAll() as it can degrade to calling toRemove.contains() for each + // element in pendingStreams. 
+ for (PendingStream stream : toRemove) { + pendingStreams.remove(stream); + } // Because delayed transport is long-lived, we take this opportunity to down-size the // hashmap. if (pendingStreams.isEmpty()) { @@ -337,14 +360,16 @@ private class PendingStream extends DelayedStream { private final PickSubchannelArgs args; private final Context context = Context.current(); private final ClientStreamTracer[] tracers; + private volatile Status lastPickStatus; private PendingStream(PickSubchannelArgs args, ClientStreamTracer[] tracers) { + super("connecting_and_lb"); this.args = args; this.tracers = tracers; } /** Runnable may be null. */ - private Runnable createRealStream(ClientTransport transport) { + private Runnable createRealStream(ClientTransport transport, String authorityOverride) { ClientStream realStream; Context origContext = context.attach(); try { @@ -354,6 +379,13 @@ private Runnable createRealStream(ClientTransport transport) { } finally { context.detach(origContext); } + if (authorityOverride != null) { + // User code provided authority takes precedence over the LB provided one; this will be + // overwritten by an enqueud call from ClientCallImpl if the application sets an authority + // override. We must call the real stream directly because stream.start() has likely already + // been called on the delayed stream. 
+ realStream.setAuthority(authorityOverride); + } return setStream(realStream); } @@ -386,6 +418,10 @@ protected void onEarlyCancellation(Status reason) { public void appendTimeoutInsight(InsightBuilder insight) { if (args.getCallOptions().isWaitForReady()) { insight.append("wait_for_ready"); + Status status = lastPickStatus; + if (status != null && !status.isOk()) { + insight.appendKeyValue("Last Pick Failure", status); + } } super.appendTimeoutInsight(insight); } diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index cc9dd0effc7..a2b1e963ac5 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -20,6 +20,8 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.CheckReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.Compressor; import io.grpc.Deadline; @@ -30,8 +32,6 @@ import java.io.InputStream; import java.util.ArrayList; import java.util.List; -import javax.annotation.CheckReturnValue; -import javax.annotation.concurrent.GuardedBy; /** * A stream that queues requests before the transport is available, and delegates to a real stream @@ -42,6 +42,7 @@ * necessary. */ class DelayedStream implements ClientStream { + private final String bufferContext; /** {@code true} once realStream is valid and all pending calls have been drained. */ private volatile boolean passThrough; /** @@ -64,6 +65,14 @@ class DelayedStream implements ClientStream { // No need to synchronize; start() synchronization provides a happens-before private List preStartPendingCalls = new ArrayList<>(); + /** + * Create a delayed stream with debug context {@code bufferContext}. The context is what this + * stream is delayed by (e.g., "connecting", "call_credentials"). 
+ */ + public DelayedStream(String bufferContext) { + this.bufferContext = checkNotNull(bufferContext, "bufferContext"); + } + @Override public void setMaxInboundMessageSize(final int maxSize) { checkState(listener == null, "May only be called before start"); @@ -104,11 +113,13 @@ public void appendTimeoutInsight(InsightBuilder insight) { return; } if (realStream != null) { - insight.appendKeyValue("buffered_nanos", streamSetTimeNanos - startTimeNanos); + insight.appendKeyValue( + bufferContext + "_delay", "" + (streamSetTimeNanos - startTimeNanos) + "ns"); realStream.appendTimeoutInsight(insight); } else { - insight.appendKeyValue("buffered_nanos", System.nanoTime() - startTimeNanos); - insight.append("waiting_for_connection"); + insight.appendKeyValue( + bufferContext + "_delay", "" + (System.nanoTime() - startTimeNanos) + "ns"); + insight.append("was_still_waiting"); } } } @@ -208,7 +219,6 @@ private void delayOrExecute(Runnable runnable) { @Override public void setAuthority(final String authority) { - checkState(listener == null, "May only be called before start"); checkNotNull(authority, "authority"); preStartPendingCalls.add(new Runnable() { @Override diff --git a/core/src/main/java/io/grpc/internal/DisconnectError.java b/core/src/main/java/io/grpc/internal/DisconnectError.java new file mode 100644 index 00000000000..771024f106e --- /dev/null +++ b/core/src/main/java/io/grpc/internal/DisconnectError.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import javax.annotation.concurrent.Immutable; + +/** + * Represents the reason for a subchannel disconnection. + * Implementations are either the SimpleDisconnectError enum or the GoAwayDisconnectError class for + * dynamic ones. + */ +@Immutable +public interface DisconnectError { + /** + * Returns the string representation suitable for use as an error tag. + * + * @return The formatted error tag string. + */ + String toErrorString(); +} diff --git a/core/src/main/java/io/grpc/internal/DnsNameResolver.java b/core/src/main/java/io/grpc/internal/DnsNameResolver.java index df51d6f2c5c..1c1d95ed616 100644 --- a/core/src/main/java/io/grpc/internal/DnsNameResolver.java +++ b/core/src/main/java/io/grpc/internal/DnsNameResolver.java @@ -23,15 +23,14 @@ import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; -import com.google.common.base.Throwables; import com.google.common.base.Verify; import com.google.common.base.VerifyException; -import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.NameResolver; import io.grpc.ProxiedSocketAddress; import io.grpc.ProxyDetector; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.internal.SharedResourceHolder.Resource; import java.io.IOException; @@ -59,7 +58,7 @@ * A DNS-based {@link NameResolver}. * *

Each {@code A} or {@code AAAA} record emits an {@link EquivalentAddressGroup} in the list - * passed to {@link NameResolver.Listener2#onResult(ResolutionResult)}. + * passed to {@link NameResolver.Listener2#onResult2(ResolutionResult)}. * * @see DnsNameResolverProvider */ @@ -100,7 +99,7 @@ public class DnsNameResolver extends NameResolver { * not installed, the ttl value is {@code null} which falls back to {@link * #DEFAULT_NETWORK_CACHE_TTL_SECONDS gRPC default value}. * - *

For android, gRPC doesn't attempt to cache; this property value will be ignored. + *

For android, gRPC uses a fixed value; this property value will be ignored. */ @VisibleForTesting static final String NETWORKADDRESS_CACHE_TTL_PROPERTY = "networkaddress.cache.ttl"; @@ -133,10 +132,10 @@ public class DnsNameResolver extends NameResolver { private final String host; private final int port; - /** Executor that will be used if an Executor is not provide via {@link NameResolver.Args}. */ - private final Resource executorResource; + private final ObjectPool executorPool; private final long cacheTtlNanos; private final SynchronizationContext syncContext; + private final ServiceConfigParser serviceConfigParser; // Following fields must be accessed from syncContext private final Stopwatch stopwatch; @@ -144,10 +143,6 @@ public class DnsNameResolver extends NameResolver { private boolean shutdown; private Executor executor; - /** True if using an executor resource that should be released after use. */ - private final boolean usingExecutorResource; - private final ServiceConfigParser serviceConfigParser; - private boolean resolving; // The field must be accessed from syncContext, although the methods on an Listener2 can be called @@ -164,7 +159,7 @@ protected DnsNameResolver( checkNotNull(args, "args"); // TODO: if a DNS server is provided as nsAuthority, use it. // https://www.captechconsulting.com/blogs/accessing-the-dusty-corners-of-dns-with-java - this.executorResource = executorResource; + // Must prepend a "//" to the name when constructing a URI, otherwise it will be treated as an // opaque URI, thus the authority and host of the resulted URI would be null. 
URI nameUri = URI.create("//" + checkNotNull(name, "name")); @@ -178,11 +173,15 @@ protected DnsNameResolver( port = nameUri.getPort(); } this.proxyDetector = checkNotNull(args.getProxyDetector(), "proxyDetector"); + Executor offloadExecutor = args.getOffloadExecutor(); + if (offloadExecutor != null) { + this.executorPool = new FixedObjectPool<>(offloadExecutor); + } else { + this.executorPool = SharedResourcePool.forResource(executorResource); + } this.cacheTtlNanos = getNetworkAddressCacheTtlNanos(isAndroid); this.stopwatch = checkNotNull(stopwatch, "stopwatch"); this.syncContext = checkNotNull(args.getSynchronizationContext(), "syncContext"); - this.executor = args.getOffloadExecutor(); - this.usingExecutorResource = executor == null; this.serviceConfigParser = checkNotNull(args.getServiceConfigParser(), "serviceConfigParser"); } @@ -199,9 +198,7 @@ protected String getHost() { @Override public void start(Listener2 listener) { Preconditions.checkState(this.listener == null, "already started"); - if (usingExecutorResource) { - executor = SharedResourceHolder.get(executorResource); - } + executor = executorPool.getObject(); this.listener = checkNotNull(listener, "listener"); resolve(); } @@ -212,20 +209,8 @@ public void refresh() { resolve(); } - private List resolveAddresses() { - List addresses; - Exception addressesException = null; - try { - addresses = addressResolver.resolveAddress(host); - } catch (Exception e) { - addressesException = e; - Throwables.throwIfUnchecked(e); - throw new RuntimeException(e); - } finally { - if (addressesException != null) { - logger.log(Level.FINE, "Address resolution failure", addressesException); - } - } + private List resolveAddresses() throws Exception { + List addresses = addressResolver.resolveAddress(host); // Each address forms an EAG List servers = new ArrayList<>(addresses.size()); for (InetAddress inetAddr : addresses) { @@ -276,21 +261,19 @@ private EquivalentAddressGroup detectProxy() throws IOException { /** * 
Main logic of name resolution. */ - protected InternalResolutionResult doResolve(boolean forceTxt) { - InternalResolutionResult result = new InternalResolutionResult(); + protected ResolutionResult doResolve() { + ResolutionResult.Builder resultBuilder = ResolutionResult.newBuilder(); try { - result.addresses = resolveAddresses(); + resultBuilder.setAddressesOrError(StatusOr.fromValue(resolveAddresses())); } catch (Exception e) { - if (!forceTxt) { - result.error = - Status.UNAVAILABLE.withDescription("Unable to resolve host " + host).withCause(e); - return result; - } + logger.log(Level.FINE, "Address resolution failure", e); + resultBuilder.setAddressesOrError(StatusOr.fromStatus( + Status.UNAVAILABLE.withDescription("Unable to resolve host " + host).withCause(e))); } if (enableTxt) { - result.config = resolveServiceConfig(); + resultBuilder.setServiceConfig(resolveServiceConfig()); } - return result; + return resultBuilder.build(); } private final class Resolve implements Runnable { @@ -305,39 +288,32 @@ public void run() { if (logger.isLoggable(Level.FINER)) { logger.finer("Attempting DNS resolution of " + host); } - InternalResolutionResult result = null; + ResolutionResult result = null; try { EquivalentAddressGroup proxiedAddr = detectProxy(); - ResolutionResult.Builder resolutionResultBuilder = ResolutionResult.newBuilder(); if (proxiedAddr != null) { if (logger.isLoggable(Level.FINER)) { logger.finer("Using proxy address " + proxiedAddr); } - resolutionResultBuilder.setAddresses(Collections.singletonList(proxiedAddr)); + result = ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(Collections.singletonList(proxiedAddr))) + .build(); } else { - result = doResolve(false); - if (result.error != null) { - savedListener.onError(result.error); - return; - } - if (result.addresses != null) { - resolutionResultBuilder.setAddresses(result.addresses); - } - if (result.config != null) { - resolutionResultBuilder.setServiceConfig(result.config); - 
} - if (result.attributes != null) { - resolutionResultBuilder.setAttributes(result.attributes); - } + result = doResolve(); } + ResolutionResult savedResult = result; syncContext.execute(() -> { - savedListener.onResult2(resolutionResultBuilder.build()); + savedListener.onResult2(savedResult); }); } catch (IOException e) { - savedListener.onError( - Status.UNAVAILABLE.withDescription("Unable to resolve host " + host).withCause(e)); + syncContext.execute(() -> + savedListener.onResult2(ResolutionResult.newBuilder() + .setAddressesOrError( + StatusOr.fromStatus( + Status.UNAVAILABLE.withDescription( + "Unable to resolve host " + host).withCause(e))).build())); } finally { - final boolean succeed = result != null && result.error == null; + final boolean succeed = result != null && result.getAddressesOrError().hasValue(); syncContext.execute(new Runnable() { @Override public void run() { @@ -403,8 +379,8 @@ public void shutdown() { return; } shutdown = true; - if (executor != null && usingExecutorResource) { - executor = SharedResourceHolder.release(executorResource, executor); + if (executor != null) { + executor = executorPool.returnObject(executor); } } @@ -455,12 +431,14 @@ private static final List getHostnamesFromChoice(Map serviceC /** * Returns value of network address cache ttl property if not Android environment. For android, - * DnsNameResolver does not cache the dns lookup result. + * DnsNameResolver uses a fixed value. */ private static long getNetworkAddressCacheTtlNanos(boolean isAndroid) { if (isAndroid) { - // on Android, ignore dns cache. - return 0; + // On Android, use fixed value. If the network used changes this value shouldn't matter, as + // channel.enterIdle() should be called and this name resolver instance will be discarded. The + // new name resolver instance will then re-request. 
+ return TimeUnit.SECONDS.toNanos(DEFAULT_NETWORK_CACHE_TTL_SECONDS); } String cacheTtlPropertyValue = System.getProperty(NETWORKADDRESS_CACHE_TTL_PROPERTY); @@ -482,7 +460,7 @@ private static long getNetworkAddressCacheTtlNanos(boolean isAndroid) { * Determines if a given Service Config choice applies, and if so, returns it. * * @see - * Service Config in DNS + * Service Config in DNS * @param choice The service config choice. * @return The service config object or {@code null} if this choice does not apply. */ @@ -537,18 +515,6 @@ private static long getNetworkAddressCacheTtlNanos(boolean isAndroid) { return sc; } - /** - * Used as a DNS-based name resolver's internal representation of resolution result. - */ - protected static final class InternalResolutionResult { - private Status error; - private List addresses; - private ConfigOrError config; - public Attributes attributes; - - private InternalResolutionResult() {} - } - /** * Describes a parsed SRV record. */ diff --git a/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java b/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java index c977fbb0cca..14b56f1a12a 100644 --- a/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java +++ b/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java @@ -21,25 +21,31 @@ import io.grpc.InternalServiceProviders; import io.grpc.NameResolver; import io.grpc.NameResolverProvider; +import io.grpc.Uri; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; import java.util.Collection; import java.util.Collections; +import java.util.List; /** * A provider for {@link DnsNameResolver}. * *

It resolves a target URI whose scheme is {@code "dns"}. The (optional) authority of the target - * URI is reserved for the address of alternative DNS server (not implemented yet). The path of the - * target URI, excluding the leading slash {@code '/'}, is treated as the host name and the optional - * port to be resolved by DNS. Example target URIs: + * URI is reserved for the address of alternative DNS server (not implemented yet). The target URI + * must be hierarchical and have exactly one path segment which will be interpreted as an RFC 2396 + * "server-based" authority and used as the "service authority" of the resulting {@link + * NameResolver}. The "host" part of this authority is the name to be resolved by DNS. The "port" + * part of this authority (if present) will become the port number for all {@link InetSocketAddress} + * produced by this resolver. For example: * *

    *
  • {@code "dns:///foo.googleapis.com:8080"} (using default DNS)
  • *
  • {@code "dns://8.8.8.8/foo.googleapis.com:8080"} (using alternative DNS (not implemented * yet))
  • - *
  • {@code "dns:///foo.googleapis.com"} (without port)
  • + *
  • {@code "dns:///foo.googleapis.com"} (output addresses will have port {@link + * NameResolver.Args#getDefaultPort()})
  • *
*/ public final class DnsNameResolverProvider extends NameResolverProvider { @@ -51,6 +57,7 @@ public final class DnsNameResolverProvider extends NameResolverProvider { @Override public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + // TODO(jdcormie): Remove once RFC 3986 migration is complete. if (SCHEME.equals(targetUri.getScheme())) { String targetPath = Preconditions.checkNotNull(targetUri.getPath(), "targetPath"); Preconditions.checkArgument(targetPath.startsWith("/"), @@ -68,6 +75,25 @@ public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { } } + @Override + public NameResolver newNameResolver(Uri targetUri, final NameResolver.Args args) { + if (SCHEME.equals(targetUri.getScheme())) { + List pathSegments = targetUri.getPathSegments(); + Preconditions.checkArgument(!pathSegments.isEmpty(), + "expected 1 path segment in target %s but found %s", targetUri, pathSegments); + String domainNameToResolve = pathSegments.get(0); + return new DnsNameResolver( + targetUri.getAuthority(), + domainNameToResolve, + args, + GrpcUtil.SHARED_CHANNEL_EXECUTOR, + Stopwatch.createUnstarted(), + IS_ANDROID); + } else { + return null; + } + } + @Override public String getDefaultScheme() { return SCHEME; diff --git a/core/src/main/java/io/grpc/internal/FailingClientTransport.java b/core/src/main/java/io/grpc/internal/FailingClientTransport.java index 5b31e6e5073..37194c46a29 100644 --- a/core/src/main/java/io/grpc/internal/FailingClientTransport.java +++ b/core/src/main/java/io/grpc/internal/FailingClientTransport.java @@ -55,7 +55,7 @@ public ClientStream newStream( public void ping(final PingCallback callback, Executor executor) { executor.execute(new Runnable() { @Override public void run() { - callback.onFailure(error.asException()); + callback.onFailure(error); } }); } diff --git a/core/src/main/java/io/grpc/internal/FixedPickerLoadBalancerProvider.java b/core/src/main/java/io/grpc/internal/FixedPickerLoadBalancerProvider.java new 
file mode 100644 index 00000000000..a632948bdb9 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/FixedPickerLoadBalancerProvider.java @@ -0,0 +1,80 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static java.util.Objects.requireNonNull; + +import io.grpc.ConnectivityState; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; +import io.grpc.Status; + +/** A LB provider whose LB always uses the same picker. 
*/ +final class FixedPickerLoadBalancerProvider extends LoadBalancerProvider { + private final ConnectivityState state; + private final LoadBalancer.SubchannelPicker picker; + private final Status acceptAddressesStatus; + + public FixedPickerLoadBalancerProvider( + ConnectivityState state, LoadBalancer.SubchannelPicker picker, Status acceptAddressesStatus) { + this.state = requireNonNull(state, "state"); + this.picker = requireNonNull(picker, "picker"); + this.acceptAddressesStatus = requireNonNull(acceptAddressesStatus, "acceptAddressesStatus"); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "fixed_picker_lb_internal"; + } + + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return new FixedPickerLoadBalancer(helper); + } + + private final class FixedPickerLoadBalancer extends LoadBalancer { + private final Helper helper; + + public FixedPickerLoadBalancer(Helper helper) { + this.helper = requireNonNull(helper, "helper"); + } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + helper.updateBalancingState(state, picker); + return acceptAddressesStatus; + } + + @Override + public void handleNameResolutionError(Status error) { + helper.updateBalancingState(state, picker); + } + + @Override + public void shutdown() {} + } +} diff --git a/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java b/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java index 06d04b6de2d..7e690309647 100644 --- a/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java @@ -67,11 +67,6 @@ public void readBytes(byte[] dest, int destOffset, int length) { buf.readBytes(dest, destOffset, length); } - @Override - public void readBytes(ByteBuffer dest) { - buf.readBytes(dest); - } - 
@Override public void readBytes(OutputStream dest, int length) throws IOException { buf.readBytes(dest, length); diff --git a/core/src/main/java/io/grpc/internal/GoAwayDisconnectError.java b/core/src/main/java/io/grpc/internal/GoAwayDisconnectError.java new file mode 100644 index 00000000000..20c8c709932 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/GoAwayDisconnectError.java @@ -0,0 +1,64 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + + +import javax.annotation.concurrent.Immutable; + +/** + * Represents a dynamic disconnection due to an HTTP/2 GOAWAY frame. + * This class is immutable and holds the specific error code from the frame. + */ +@Immutable +public final class GoAwayDisconnectError implements DisconnectError { + private static final String ERROR_TAG = "GOAWAY"; + private final GrpcUtil.Http2Error errorCode; + + /** + * Creates a GoAway reason. + * + * @param errorCode The specific, non-null HTTP/2 error code (e.g., "NO_ERROR"). 
+ */ + public GoAwayDisconnectError(GrpcUtil.Http2Error errorCode) { + if (errorCode == null) { + throw new NullPointerException("Http2Error cannot be null for GOAWAY"); + } + this.errorCode = errorCode; + } + + @Override + public String toErrorString() { + return ERROR_TAG + " " + errorCode.name(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GoAwayDisconnectError goAwayDisconnectError = (GoAwayDisconnectError) o; + return errorCode == goAwayDisconnectError.errorCode; + } + + @Override + public int hashCode() { + return errorCode.hashCode(); + } +} diff --git a/core/src/main/java/io/grpc/internal/GrpcAttributes.java b/core/src/main/java/io/grpc/internal/GrpcAttributes.java index da43ae14800..f95f9b9dab8 100644 --- a/core/src/main/java/io/grpc/internal/GrpcAttributes.java +++ b/core/src/main/java/io/grpc/internal/GrpcAttributes.java @@ -42,5 +42,8 @@ public final class GrpcAttributes { public static final Attributes.Key ATTR_CLIENT_EAG_ATTRS = Attributes.Key.create("io.grpc.internal.GrpcAttributes.clientEagAttrs"); + public static final Attributes.Key ATTR_AUTHORITY_VERIFIER = + Attributes.Key.create("io.grpc.internal.GrpcAttributes.authorityVerifier"); + private GrpcAttributes() {} } diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index a1fe34c2edc..deae5d831b8 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -24,7 +24,6 @@ import com.google.common.base.Preconditions; import com.google.common.base.Splitter; import com.google.common.base.Stopwatch; -import com.google.common.base.Strings; import com.google.common.base.Supplier; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -32,6 +31,7 @@ import io.grpc.ClientStreamTracer; import 
io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.InternalChannelz.SocketStats; +import io.grpc.InternalFeatureFlags; import io.grpc.InternalLogId; import io.grpc.InternalMetadata; import io.grpc.InternalMetadata.TrustedAsciiMarshaller; @@ -219,7 +219,7 @@ public byte[] parseAsciiString(byte[] serialized) { public static final Splitter ACCEPT_ENCODING_SPLITTER = Splitter.on(',').trimResults(); - public static final String IMPLEMENTATION_VERSION = "1.68.0-SNAPSHOT"; // CURRENT_GRPC_VERSION + public static final String IMPLEMENTATION_VERSION = "1.81.0-SNAPSHOT"; // CURRENT_GRPC_VERSION /** * The default timeout in nanos for a keepalive ping request. @@ -651,12 +651,14 @@ public Stopwatch get() { static class TimeoutMarshaller implements Metadata.AsciiMarshaller { @Override - public String toAsciiString(Long timeoutNanos) { + public String toAsciiString(Long timeoutNanosObject) { long cutoff = 100000000; + // Timeout checking is inherently racy. RPCs with timeouts in the past ideally don't even get + // here, but if the timeout is expired assume that happened recently and adjust it to the + // smallest allowed timeout + long timeoutNanos = Math.max(1, timeoutNanosObject); TimeUnit unit = TimeUnit.NANOSECONDS; - if (timeoutNanos < 0) { - throw new IllegalArgumentException("Timeout too small"); - } else if (timeoutNanos < cutoff) { + if (timeoutNanos < cutoff) { return timeoutNanos + "n"; } else if (timeoutNanos < cutoff * 1000L) { return unit.toMicros(timeoutNanos) + "u"; @@ -757,13 +759,15 @@ public ListenableFuture getStats() { /** Gets stream tracers based on CallOptions. 
*/ public static ClientStreamTracer[] getClientStreamTracers( - CallOptions callOptions, Metadata headers, int previousAttempts, boolean isTransparentRetry) { + CallOptions callOptions, Metadata headers, int previousAttempts, boolean isTransparentRetry, + boolean isHedging) { List factories = callOptions.getStreamTracerFactories(); ClientStreamTracer[] tracers = new ClientStreamTracer[factories.size() + 1]; StreamInfo streamInfo = StreamInfo.newBuilder() .setCallOptions(callOptions) .setPreviousAttempts(previousAttempts) .setIsTransparentRetry(isTransparentRetry) + .setIsHedging(isHedging) .build(); for (int i = 0; i < factories.size(); i++) { tracers[i] = factories.get(i).newClientStreamTracer(streamInfo, headers); @@ -817,6 +821,31 @@ public static Status replaceInappropriateControlPlaneStatus(Status status) { + status.getDescription()).withCause(status.getCause()) : status; } + /** + * Returns a "clean" representation of a status code and description (not cause) like + * "UNAVAILABLE: The description". Should be similar to Status.formatThrowableMessage(). + */ + public static String statusToPrettyString(Status status) { + if (status.getDescription() == null) { + return status.getCode().toString(); + } else { + return status.getCode() + ": " + status.getDescription(); + } + } + + /** + * Create a status with contextual information, propagating details from a non-null status that + * contributed to the failure. For example, if UNAVAILABLE, "Couldn't load bar", and status + * "FAILED_PRECONDITION: Foo missing" were passed as arguments, then this method would produce the + * status "UNAVAILABLE: Couldn't load bar: FAILED_PRECONDITION: Foo missing" with a cause if the + * passed status had a cause. 
+ */ + public static Status statusWithDetails(Status.Code code, String description, Status causeStatus) { + return code.toStatus() + .withDescription(description + ": " + statusToPrettyString(causeStatus)) + .withCause(causeStatus.getCause()); + } + /** * Checks whether the given item exists in the iterable. This is copied from Guava Collect's * {@code Iterables.contains()} because Guava Collect is not Android-friendly thus core can't @@ -929,18 +958,7 @@ public static String encodeAuthority(String authority) { } public static boolean getFlag(String envVarName, boolean enableByDefault) { - String envVar = System.getenv(envVarName); - if (envVar == null) { - envVar = System.getProperty(envVarName); - } - if (envVar != null) { - envVar = envVar.trim(); - } - if (enableByDefault) { - return Strings.isNullOrEmpty(envVar) || Boolean.parseBoolean(envVar); - } else { - return !Strings.isNullOrEmpty(envVar) && Boolean.parseBoolean(envVar); - } + return InternalFeatureFlags.getFlag(envVarName, enableByDefault); } diff --git a/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java b/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java index e92bb7a4af1..5560a1abb6d 100644 --- a/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java +++ b/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java @@ -140,6 +140,7 @@ protected void transportDataReceived(ReadableBuffer frame, boolean endOfStream) } } else { if (!headersReceived) { + frame.close(); http2ProcessingFailed( Status.INTERNAL.withDescription("headers not received before payload"), false, diff --git a/core/src/main/java/io/grpc/internal/Http2Ping.java b/core/src/main/java/io/grpc/internal/Http2Ping.java index 6104d876373..e3520295625 100644 --- a/core/src/main/java/io/grpc/internal/Http2Ping.java +++ b/core/src/main/java/io/grpc/internal/Http2Ping.java @@ -17,6 +17,8 @@ package io.grpc.internal; import com.google.common.base.Stopwatch; +import 
com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.Status; import io.grpc.internal.ClientTransport.PingCallback; import java.util.LinkedHashMap; import java.util.Map; @@ -24,7 +26,6 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.concurrent.GuardedBy; /** * Represents an outstanding PING operation on an HTTP/2 channel. This can be used by HTTP/2-based @@ -62,7 +63,7 @@ public class Http2Ping { /** * If non-null, indicates the ping failed. */ - @GuardedBy("this") private Throwable failureCause; + @GuardedBy("this") private Status failureCause; /** * The round-trip time for the ping, in nanoseconds. This value is only meaningful when @@ -144,7 +145,7 @@ public boolean complete() { * * @param failureCause the cause of failure */ - public void failed(Throwable failureCause) { + public void failed(Status failureCause) { Map callbacks; synchronized (this) { if (completed) { @@ -167,7 +168,7 @@ public void failed(Throwable failureCause) { * @param executor the executor used to invoke the callback * @param cause the cause of failure */ - public static void notifyFailed(PingCallback callback, Executor executor, Throwable cause) { + public static void notifyFailed(PingCallback callback, Executor executor, Status cause) { doExecute(executor, asRunnable(callback, cause)); } @@ -203,7 +204,7 @@ public void run() { * failure. 
*/ private static Runnable asRunnable(final ClientTransport.PingCallback callback, - final Throwable failureCause) { + final Status failureCause) { return new Runnable() { @Override public void run() { diff --git a/core/src/main/java/io/grpc/internal/InstantTimeProvider.java b/core/src/main/java/io/grpc/internal/InstantTimeProvider.java new file mode 100644 index 00000000000..12996163753 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/InstantTimeProvider.java @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.math.LongMath.saturatedAdd; + +import java.time.Instant; +import java.util.concurrent.TimeUnit; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; + +/** + * {@link InstantTimeProvider} resolves InstantTimeProvider which implements {@link TimeProvider}. 
+ */ +final class InstantTimeProvider implements TimeProvider { + @Override + @IgnoreJRERequirement + public long currentTimeNanos() { + Instant now = Instant.now(); + long epochSeconds = now.getEpochSecond(); + return saturatedAdd(TimeUnit.SECONDS.toNanos(epochSeconds), now.getNano()); + } +} diff --git a/core/src/main/java/io/grpc/internal/InternalSubchannel.java b/core/src/main/java/io/grpc/internal/InternalSubchannel.java index 70e42e2f5f1..ce31921e316 100644 --- a/core/src/main/java/io/grpc/internal/InternalSubchannel.java +++ b/core/src/main/java/io/grpc/internal/InternalSubchannel.java @@ -45,8 +45,12 @@ import io.grpc.InternalInstrumented; import io.grpc.InternalLogId; import io.grpc.InternalWithLogId; +import io.grpc.LoadBalancer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MetricRecorder; +import io.grpc.NameResolver; +import io.grpc.SecurityLevel; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; @@ -76,7 +80,9 @@ final class InternalSubchannel implements InternalInstrumented, Tr private final InternalChannelz channelz; private final CallTracer callsTracer; private final ChannelTracer channelTracer; + private final MetricRecorder metricRecorder; private final ChannelLogger channelLogger; + private final boolean reconnectDisabled; private final List transportFilters; @@ -158,14 +164,20 @@ protected void handleNotInUse() { private Status shutdownReason; private volatile Attributes connectedAddressAttributes; - - InternalSubchannel(List addressGroups, String authority, String userAgent, - BackoffPolicy.Provider backoffPolicyProvider, - ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor, - Supplier stopwatchSupplier, SynchronizationContext syncContext, Callback callback, - InternalChannelz channelz, CallTracer callsTracer, ChannelTracer channelTracer, - InternalLogId logId, ChannelLogger channelLogger, - List transportFilters) { + 
private final SubchannelMetrics subchannelMetrics; + private final String target; + + InternalSubchannel(LoadBalancer.CreateSubchannelArgs args, String authority, String userAgent, + BackoffPolicy.Provider backoffPolicyProvider, + ClientTransportFactory transportFactory, + ScheduledExecutorService scheduledExecutor, + Supplier stopwatchSupplier, SynchronizationContext syncContext, + Callback callback, InternalChannelz channelz, CallTracer callsTracer, + ChannelTracer channelTracer, InternalLogId logId, + ChannelLogger channelLogger, List transportFilters, + String target, + MetricRecorder metricRecorder) { + List addressGroups = args.getAddresses(); Preconditions.checkNotNull(addressGroups, "addressGroups"); Preconditions.checkArgument(!addressGroups.isEmpty(), "addressGroups is empty"); checkListHasNoNulls(addressGroups, "addressGroups contains null entry"); @@ -180,6 +192,7 @@ protected void handleNotInUse() { this.scheduledExecutor = scheduledExecutor; this.connectingTimer = stopwatchSupplier.get(); this.syncContext = syncContext; + this.metricRecorder = metricRecorder; this.callback = callback; this.channelz = channelz; this.callsTracer = callsTracer; @@ -187,6 +200,9 @@ protected void handleNotInUse() { this.logId = Preconditions.checkNotNull(logId, "logId"); this.channelLogger = Preconditions.checkNotNull(channelLogger, "channelLogger"); this.transportFilters = transportFilters; + this.reconnectDisabled = args.getOption(LoadBalancer.DISABLE_SUBCHANNEL_RECONNECT_KEY); + this.target = target; + this.subchannelMetrics = new SubchannelMetrics(metricRecorder); } ChannelLogger getChannelLogger() { @@ -251,6 +267,7 @@ private void startNewTransport() { .setAuthority(eagChannelAuthority != null ? 
eagChannelAuthority : authority) .setEagAttributes(currentEagAttributes) .setUserAgent(userAgent) + .setMetricRecorder(metricRecorder) .setHttpConnectProxiedSocketAddress(proxiedAddr); TransportLogger transportLogger = new TransportLogger(); // In case the transport logs in the constructor, use the subchannel logId @@ -289,6 +306,11 @@ public void run() { } gotoState(ConnectivityStateInfo.forTransientFailure(status)); + + if (reconnectDisabled) { + return; + } + if (reconnectPolicy == null) { reconnectPolicy = backoffPolicyProvider.get(); } @@ -307,7 +329,7 @@ public void run() { } /** - * Immediately attempt to reconnect if the current state is TRANSIENT_FAILURE. Otherwise this + * Immediately attempt to reconnect if the current state is TRANSIENT_FAILURE. Otherwise, this * method has no effect. */ void resetConnectBackoff() { @@ -336,8 +358,12 @@ private void gotoState(final ConnectivityStateInfo newState) { if (state.getState() != newState.getState()) { Preconditions.checkState(state.getState() != SHUTDOWN, - "Cannot transition out of SHUTDOWN to " + newState); - state = newState; + "Cannot transition out of SHUTDOWN to %s", newState.getState()); + if (reconnectDisabled && newState.getState() == TRANSIENT_FAILURE) { + state = ConnectivityStateInfo.forNonError(IDLE); + } else { + state = newState; + } callback.onStateChange(InternalSubchannel.this, newState); } } @@ -579,6 +605,13 @@ public void run() { pendingTransport = null; connectedAddressAttributes = addressIndex.getCurrentEagAttributes(); gotoNonErrorState(READY); + subchannelMetrics.recordConnectionAttemptSucceeded(/* target= */ target, + /* backendService= */ getAttributeOrDefault( + addressIndex.getCurrentEagAttributes(), NameResolver.ATTR_BACKEND_SERVICE), + /* locality= */ getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), + EquivalentAddressGroup.ATTR_LOCALITY_NAME), + /* securityLevel= */ extractSecurityLevel(addressIndex.getCurrentEagAttributes() + 
.get(GrpcAttributes.ATTR_SECURITY_LEVEL))); } } }); @@ -590,7 +623,7 @@ public void transportInUse(boolean inUse) { } @Override - public void transportShutdown(final Status s) { + public void transportShutdown(final Status s, final DisconnectError disconnectError) { channelLogger.log( ChannelLogLevel.INFO, "{0} SHUTDOWN with {1}", transport.getLogId(), printShortStatus(s)); shutdownInitiated = true; @@ -604,11 +637,24 @@ public void run() { activeTransport = null; addressIndex.reset(); gotoNonErrorState(IDLE); + subchannelMetrics.recordDisconnection(/* target= */ target, + /* backendService= */ getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), + NameResolver.ATTR_BACKEND_SERVICE), + /* locality= */ getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), + EquivalentAddressGroup.ATTR_LOCALITY_NAME), + /* disconnectError= */ disconnectError.toErrorString(), + /* securityLevel= */ extractSecurityLevel(addressIndex.getCurrentEagAttributes() + .get(GrpcAttributes.ATTR_SECURITY_LEVEL))); } else if (pendingTransport == transport) { + subchannelMetrics.recordConnectionAttemptFailed(/* target= */ target, + /* backendService= */getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), + NameResolver.ATTR_BACKEND_SERVICE), + /* locality= */ getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), + EquivalentAddressGroup.ATTR_LOCALITY_NAME)); Preconditions.checkState(state.getState() == CONNECTING, "Expected state is CONNECTING, actual state is %s", state.getState()); addressIndex.increment(); - // Continue reconnect if there are still addresses to try. + // Continue to reconnect if there are still addresses to try. 
if (!addressIndex.isValid()) { pendingTransport = null; addressIndex.reset(); @@ -644,6 +690,27 @@ public void run() { } }); } + + private String extractSecurityLevel(SecurityLevel securityLevel) { + if (securityLevel == null) { + return "none"; + } + switch (securityLevel) { + case NONE: + return "none"; + case INTEGRITY: + return "integrity_only"; + case PRIVACY_AND_INTEGRITY: + return "privacy_and_integrity"; + default: + throw new IllegalArgumentException("Unknown SecurityLevel: " + securityLevel); + } + } + + private String getAttributeOrDefault(Attributes attributes, Attributes.Key key) { + String value = attributes.get(key); + return value == null ? "" : value; + } } // All methods are called in syncContext diff --git a/core/src/main/java/io/grpc/internal/JsonParser.java b/core/src/main/java/io/grpc/internal/JsonParser.java index 384d29754f0..14f78c09e72 100644 --- a/core/src/main/java/io/grpc/internal/JsonParser.java +++ b/core/src/main/java/io/grpc/internal/JsonParser.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import com.google.gson.stream.JsonReader; @@ -41,7 +42,8 @@ private JsonParser() {} /** * Parses a json string, returning either a {@code Map}, {@code List}, - * {@code String}, {@code Double}, {@code Boolean}, or {@code null}. + * {@code String}, {@code Double}, {@code Boolean}, or {@code null}. Fails if duplicate names + * found. 
*/ public static Object parse(String raw) throws IOException { JsonReader jr = new JsonReader(new StringReader(raw)); @@ -81,6 +83,7 @@ private static Object parseRecursive(JsonReader jr) throws IOException { Map obj = new LinkedHashMap<>(); while (jr.hasNext()) { String name = jr.nextName(); + checkArgument(!obj.containsKey(name), "Duplicate key found: %s", name); Object value = parseRecursive(jr); obj.put(name, value); } @@ -105,4 +108,4 @@ private static Void parseJsonNull(JsonReader jr) throws IOException { jr.nextNull(); return null; } -} +} \ No newline at end of file diff --git a/core/src/main/java/io/grpc/internal/JsonUtil.java b/core/src/main/java/io/grpc/internal/JsonUtil.java index 44cb22abda5..6c9274702b6 100644 --- a/core/src/main/java/io/grpc/internal/JsonUtil.java +++ b/core/src/main/java/io/grpc/internal/JsonUtil.java @@ -356,23 +356,24 @@ private static int parseNanos(String value) throws ParseException { return result; } - private static final long NANOS_PER_SECOND = TimeUnit.SECONDS.toNanos(1); + private static final int NANOS_PER_SECOND = 1_000_000_000; /** * Copy of {@link com.google.protobuf.util.Durations#normalizedDuration}. 
*/ - @SuppressWarnings("NarrowingCompoundAssignment") + // Math.addExact() requires Android API level 24 + @SuppressWarnings({"NarrowingCompoundAssignment", "InlineMeInliner"}) private static long normalizedDuration(long seconds, int nanos) { if (nanos <= -NANOS_PER_SECOND || nanos >= NANOS_PER_SECOND) { seconds = checkedAdd(seconds, nanos / NANOS_PER_SECOND); nanos %= NANOS_PER_SECOND; } if (seconds > 0 && nanos < 0) { - nanos += NANOS_PER_SECOND; // no overflow since nanos is negative (and we're adding) + nanos += NANOS_PER_SECOND; // no overflow— nanos is negative (and we're adding) seconds--; // no overflow since seconds is positive (and we're decrementing) } if (seconds < 0 && nanos > 0) { - nanos -= NANOS_PER_SECOND; // no overflow since nanos is positive (and we're subtracting) + nanos -= NANOS_PER_SECOND; // no overflow— nanos is positive (and we're subtracting) seconds++; // no overflow since seconds is negative (and we're incrementing) } if (!durationIsValid(seconds, nanos)) { diff --git a/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java b/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java index dd539e75a18..6480336470c 100644 --- a/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java +++ b/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java @@ -18,8 +18,8 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; import java.util.concurrent.TimeUnit; -import javax.annotation.CheckReturnValue; /** Monitors the client's PING usage to make sure the rate is permitted. 
*/ public final class KeepAliveEnforcer { diff --git a/core/src/main/java/io/grpc/internal/KeepAliveManager.java b/core/src/main/java/io/grpc/internal/KeepAliveManager.java index 28e2a87276b..535b3a82524 100644 --- a/core/src/main/java/io/grpc/internal/KeepAliveManager.java +++ b/core/src/main/java/io/grpc/internal/KeepAliveManager.java @@ -22,11 +22,12 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Status; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; -import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; /** * Manages keepalive pings. @@ -262,9 +263,25 @@ public interface KeepAlivePinger { * Default client side {@link KeepAlivePinger}. */ public static final class ClientKeepAlivePinger implements KeepAlivePinger { - private final ConnectionClientTransport transport; - public ClientKeepAlivePinger(ConnectionClientTransport transport) { + + /** + * A {@link ClientTransport} that has life-cycle management. + * + */ + @ThreadSafe + public interface TransportWithDisconnectReason extends ClientTransport { + + /** + * Initiates a forceful shutdown in which preexisting and new calls are closed. Existing calls + * should be closed with the provided {@code reason} and {@code disconnectError}. 
+ */ + void shutdownNow(Status reason, DisconnectError disconnectError); + } + + private final TransportWithDisconnectReason transport; + + public ClientKeepAlivePinger(TransportWithDisconnectReason transport) { this.transport = transport; } @@ -275,9 +292,10 @@ public void ping() { public void onSuccess(long roundTripTimeNanos) {} @Override - public void onFailure(Throwable cause) { + public void onFailure(Status cause) { transport.shutdownNow(Status.UNAVAILABLE.withDescription( - "Keepalive failed. The connection is likely gone")); + "Keepalive failed. The connection is likely gone"), + SimpleDisconnectError.CONNECTION_TIMED_OUT); } }, MoreExecutors.directExecutor()); } @@ -285,7 +303,8 @@ public void onFailure(Throwable cause) { @Override public void onPingTimeout() { transport.shutdownNow(Status.UNAVAILABLE.withDescription( - "Keepalive failed. The connection is likely gone")); + "Keepalive failed. The connection is likely gone"), + SimpleDisconnectError.CONNECTION_TIMED_OUT); } } } diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 07dcf9ee7bb..e423220e3ad 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -32,6 +32,7 @@ import com.google.common.base.Supplier; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallCredentials; import io.grpc.CallOptions; @@ -68,6 +69,7 @@ import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.LoadBalancerProvider; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; @@ -81,9 +83,9 @@ import io.grpc.NameResolverRegistry; import 
io.grpc.ProxyDetector; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; -import io.grpc.internal.AutoConfiguredLoadBalancerFactory.AutoConfiguredLoadBalancer; import io.grpc.internal.ClientCallImpl.ClientStreamProvider; import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; @@ -92,8 +94,8 @@ import io.grpc.internal.ManagedChannelServiceConfig.ServiceConfigConvertedSelector; import io.grpc.internal.RetriableStream.ChannelBufferMeter; import io.grpc.internal.RetriableStream.Throttle; -import io.grpc.internal.RetryingNameResolver.ResolutionResultListener; import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -116,7 +118,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** A communication channel for making outgoing RPCs. 
*/ @@ -159,19 +160,17 @@ public Result selectConfig(PickSubchannelArgs args) { @Nullable private final String authorityOverride; private final NameResolverRegistry nameResolverRegistry; - private final URI targetUri; + private final UriWrapper targetUri; private final NameResolverProvider nameResolverProvider; private final NameResolver.Args nameResolverArgs; - private final AutoConfiguredLoadBalancerFactory loadBalancerFactory; + private final LoadBalancerProvider loadBalancerFactory; private final ClientTransportFactory originalTransportFactory; @Nullable private final ChannelCredentials originalChannelCreds; private final ClientTransportFactory transportFactory; - private final ClientTransportFactory oobTransportFactory; private final RestrictedScheduledExecutor scheduledExecutor; private final Executor executor; private final ObjectPool executorPool; - private final ObjectPool balancerRpcExecutorPool; private final ExecutorHolder balancerRpcExecutorHolder; private final ExecutorHolder offloadExecutorHolder; private final TimeProvider timeProvider; @@ -186,7 +185,12 @@ public void uncaughtException(Thread t, Throwable e) { Level.SEVERE, "[" + getLogId() + "] Uncaught exception in the SynchronizationContext. Panic!", e); - panic(e); + try { + panic(e); + } catch (Throwable anotherT) { + logger.log( + Level.SEVERE, "[" + getLogId() + "] Uncaught exception while panicking", anotherT); + } } }); @@ -222,11 +226,6 @@ public void uncaughtException(Thread t, Throwable e) { @Nullable private LbHelperImpl lbHelper; - // Must ONLY be assigned from updateSubchannelPicker(), which is called from syncContext. - // null if channel is in idle mode. 
- @Nullable - private volatile SubchannelPicker subchannelPicker; - // Must be accessed from the syncContext private boolean panicMode; @@ -240,9 +239,6 @@ public void uncaughtException(Thread t, Throwable e) { private Collection> pendingCalls; private final Object pendingCallsInUseObject = new Object(); - // Must be mutated from syncContext - private final Set oobChannels = new HashSet<>(1, .75f); - // reprocess() must be run from syncContext private final DelayedClientTransport delayedTransport; private final UncommittedRetriableStreamsRegistry uncommittedRetriableStreamsRegistry @@ -253,8 +249,7 @@ public void uncaughtException(Thread t, Throwable e) { // Channel's shutdown process: // 1. shutdown(): stop accepting new calls from applications // 1a shutdown <- true - // 1b subchannelPicker <- null - // 1c delayedTransport.shutdown() + // 1b delayedTransport.shutdown() // 2. delayedTransport terminated: stop stream-creation functionality // 2a terminating <- true // 2b loadBalancer.shutdown() @@ -313,9 +308,6 @@ private void maybeShutdownNowSubchannels() { for (InternalSubchannel subchannel : subchannels) { subchannel.shutdownNow(SHUTDOWN_NOW_STATUS); } - for (OobChannel oobChannel : oobChannels) { - oobChannel.getInternalSubchannel().shutdownNow(SHUTDOWN_NOW_STATUS); - } } } @@ -335,7 +327,6 @@ public void run() { builder.setTarget(target).setState(channelStateManager.getState()); List children = new ArrayList<>(); children.addAll(subchannels); - children.addAll(oobChannels); builder.setSubchannels(children); ret.set(builder.build()); } @@ -387,7 +378,6 @@ private void shutdownNameResolverAndLoadBalancer(boolean channelIsActive) { lbHelper.lb.shutdown(); lbHelper = null; } - subchannelPicker = null; } /** @@ -417,7 +407,7 @@ void exitIdleMode() { LbHelperImpl lbHelper = new LbHelperImpl(); lbHelper.lb = loadBalancerFactory.newLoadBalancer(lbHelper); // Delay setting lbHelper until fully initialized, since loadBalancerFactory is user code and - // may throw. 
We don't want to confuse our state, even if we will enter panic mode. + // may throw. We don't want to confuse our state, even if we enter panic mode. this.lbHelper = lbHelper; channelStateManager.gotoState(CONNECTING); @@ -481,7 +471,8 @@ public ClientStream newStream( // the delayed transport or a real transport will go in-use and cancel the idle timer. if (!retryEnabled) { ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( - callOptions, headers, 0, /* isTransparentRetry= */ false); + callOptions, headers, 0, /* isTransparentRetry= */ false, + /* isHedging= */false); Context origContext = context.attach(); try { return delayedTransport.newStream(method, headers, callOptions, tracers); @@ -521,10 +512,10 @@ void postCommit() { @Override ClientStream newSubstream( Metadata newHeaders, ClientStreamTracer.Factory factory, int previousAttempts, - boolean isTransparentRetry) { + boolean isTransparentRetry, boolean isHedgedStream) { CallOptions newOptions = callOptions.withStreamTracerFactory(factory); ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( - newOptions, newHeaders, previousAttempts, isTransparentRetry); + newOptions, newHeaders, previousAttempts, isTransparentRetry, isHedgedStream); Context origContext = context.attach(); try { return delayedTransport.newStream(method, newHeaders, newOptions, tracers); @@ -547,7 +538,7 @@ ClientStream newSubstream( ManagedChannelImpl( ManagedChannelImplBuilder builder, ClientTransportFactory clientTransportFactory, - URI targetUri, + UriWrapper targetUri, NameResolverProvider nameResolverProvider, BackoffPolicy.Provider backoffPolicyProvider, ObjectPool balancerRpcExecutorPool, @@ -565,8 +556,6 @@ ClientStream newSubstream( new ExecutorHolder(checkNotNull(builder.offloadExecutorPool, "offloadExecutorPool")); this.transportFactory = new CallCredentialsApplyingTransportFactory( clientTransportFactory, builder.callCredentials, this.offloadExecutorHolder); - this.oobTransportFactory = new 
CallCredentialsApplyingTransportFactory( - clientTransportFactory, null, this.offloadExecutorHolder); this.scheduledExecutor = new RestrictedScheduledExecutor(transportFactory.getScheduledExecutorService()); maxTraceEvents = builder.maxTraceEvents; @@ -588,8 +577,9 @@ ClientStream newSubstream( builder.maxHedgedAttempts, loadBalancerFactory); this.authorityOverride = builder.authorityOverride; - this.nameResolverArgs = - NameResolver.Args.newBuilder() + this.metricRecorder = new MetricRecorderImpl(builder.metricSinks, + MetricInstrumentRegistry.getDefaultRegistry()); + NameResolver.Args.Builder nameResolverArgsBuilder = NameResolver.Args.newBuilder() .setDefaultPort(builder.getDefaultPort()) .setProxyDetector(proxyDetector) .setSynchronizationContext(syncContext) @@ -598,11 +588,14 @@ ClientStream newSubstream( .setChannelLogger(channelLogger) .setOffloadExecutor(this.offloadExecutorHolder) .setOverrideAuthority(this.authorityOverride) - .build(); + .setMetricRecorder(this.metricRecorder) + .setNameResolverRegistry(builder.nameResolverRegistry); + builder.copyAllNameResolverCustomArgsTo(nameResolverArgsBuilder); + this.nameResolverArgs = nameResolverArgsBuilder.build(); this.nameResolver = getNameResolver( targetUri, authorityOverride, nameResolverProvider, nameResolverArgs); - this.balancerRpcExecutorPool = checkNotNull(balancerRpcExecutorPool, "balancerRpcExecutorPool"); - this.balancerRpcExecutorHolder = new ExecutorHolder(balancerRpcExecutorPool); + this.balancerRpcExecutorHolder = new ExecutorHolder( + checkNotNull(balancerRpcExecutorPool, "balancerRpcExecutorPool")); this.delayedTransport = new DelayedClientTransport(this.executor, this.syncContext); this.delayedTransport.start(delayedTransportListener); this.backoffPolicyProvider = backoffPolicyProvider; @@ -670,15 +663,13 @@ public CallTracer create() { } serviceConfigUpdated = true; } - this.metricRecorder = new MetricRecorderImpl(builder.metricSinks, - MetricInstrumentRegistry.getDefaultRegistry()); } 
@VisibleForTesting static NameResolver getNameResolver( - URI targetUri, @Nullable final String overrideAuthority, + UriWrapper targetUri, @Nullable final String overrideAuthority, NameResolverProvider provider, NameResolver.Args nameResolverArgs) { - NameResolver resolver = provider.newNameResolver(targetUri, nameResolverArgs); + NameResolver resolver = targetUri.newNameResolver(provider, nameResolverArgs); if (resolver == null) { throw new IllegalArgumentException("cannot create a NameResolver for " + targetUri); } @@ -686,11 +677,7 @@ static NameResolver getNameResolver( // We wrap the name resolver in a RetryingNameResolver to give it the ability to retry failures. // TODO: After a transition period, all NameResolver implementations that need retry should use // RetryingNameResolver directly and this step can be removed. - NameResolver usedNameResolver = new RetryingNameResolver(resolver, - new BackoffPolicyRetryScheduler(new ExponentialBackoffPolicy.Provider(), - nameResolverArgs.getScheduledExecutorService(), - nameResolverArgs.getSynchronizationContext()), - nameResolverArgs.getSynchronizationContext()); + NameResolver usedNameResolver = RetryingNameResolver.wrap(resolver, nameResolverArgs); if (overrideAuthority == null) { return usedNameResolver; @@ -778,30 +765,16 @@ void panic(final Throwable t) { return; } panicMode = true; - cancelIdleTimer(/* permanent= */ true); - shutdownNameResolverAndLoadBalancer(false); - final class PanicSubchannelPicker extends SubchannelPicker { - private final PickResult panicPickResult = - PickResult.withDrop( - Status.INTERNAL.withDescription("Panic! 
This is a bug!").withCause(t)); - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return panicPickResult; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(PanicSubchannelPicker.class) - .add("panicPickResult", panicPickResult) - .toString(); - } + try { + cancelIdleTimer(/* permanent= */ true); + shutdownNameResolverAndLoadBalancer(false); + } finally { + updateSubchannelPicker(new LoadBalancer.FixedResultPicker(PickResult.withDrop( + Status.INTERNAL.withDescription("Panic! This is a bug!").withCause(t)))); + realChannel.updateConfigSelector(null); + channelLogger.log(ChannelLogLevel.ERROR, "PANIC! Entering TRANSIENT_FAILURE"); + channelStateManager.gotoState(TRANSIENT_FAILURE); } - - updateSubchannelPicker(new PanicSubchannelPicker()); - realChannel.updateConfigSelector(null); - channelLogger.log(ChannelLogLevel.ERROR, "PANIC! Entering TRANSIENT_FAILURE"); - channelStateManager.gotoState(TRANSIENT_FAILURE); } @VisibleForTesting @@ -811,7 +784,6 @@ boolean isInPanicMode() { // Called from syncContext private void updateSubchannelPicker(SubchannelPicker newPicker) { - subchannelPicker = newPicker; delayedTransport.reprocess(newPicker); } @@ -956,7 +928,15 @@ void updateConfigSelector(@Nullable InternalConfigSelector config) { // Must run in SynchronizationContext. void onConfigError() { if (configSelector.get() == INITIAL_PENDING_SELECTOR) { - updateConfigSelector(null); + // Apply Default Service Config if initial name resolution fails. 
+ if (defaultServiceConfig != null) { + updateConfigSelector(defaultServiceConfig.getDefaultConfigSelector()); + lastServiceConfig = defaultServiceConfig; + channelLogger.log(ChannelLogLevel.ERROR, + "Initial Name Resolution error, using default service config"); + } else { + updateConfigSelector(null); + } } } @@ -1197,7 +1177,7 @@ private void maybeTerminateChannel() { if (terminated) { return; } - if (shutdown.get() && subchannels.isEmpty() && oobChannels.isEmpty()) { + if (shutdown.get() && subchannels.isEmpty()) { channelLogger.log(ChannelLogLevel.INFO, "Terminated"); channelz.removeRootChannel(this); executorPool.returnObject(executor); @@ -1211,15 +1191,7 @@ private void maybeTerminateChannel() { } } - // Must be called from syncContext - private void handleInternalSubchannelState(ConnectivityStateInfo newState) { - if (newState.getState() == TRANSIENT_FAILURE || newState.getState() == IDLE) { - refreshNameResolution(); - } - } - @Override - @SuppressWarnings("deprecation") public ConnectivityState getState(boolean requestConnection) { ConnectivityState savedChannelState = channelStateManager.getState(); if (requestConnection && savedChannelState == IDLE) { @@ -1227,9 +1199,6 @@ final class RequestConnection implements Runnable { @Override public void run() { exitIdleMode(); - if (subchannelPicker != null) { - subchannelPicker.requestConnection(); - } if (lbHelper != null) { lbHelper.lb.requestConnection(); } @@ -1267,9 +1236,6 @@ public void run() { for (InternalSubchannel subchannel : subchannels) { subchannel.resetConnectBackoff(); } - for (OobChannel oobChannel : oobChannels) { - oobChannel.resetConnectBackoff(); - } } } @@ -1376,7 +1342,7 @@ void remove(RetriableStream retriableStream) { } private final class LbHelperImpl extends LoadBalancer.Helper { - AutoConfiguredLoadBalancer lb; + LoadBalancer lb; @Override public AbstractSubchannel createSubchannel(CreateSubchannelArgs args) { @@ -1392,24 +1358,18 @@ public void updateBalancingState( 
syncContext.throwIfNotInThisSynchronizationContext(); checkNotNull(newState, "newState"); checkNotNull(newPicker, "newPicker"); - final class UpdateBalancingState implements Runnable { - @Override - public void run() { - if (LbHelperImpl.this != lbHelper) { - return; - } - updateSubchannelPicker(newPicker); - // It's not appropriate to report SHUTDOWN state from lb. - // Ignore the case of newState == SHUTDOWN for now. - if (newState != SHUTDOWN) { - channelLogger.log( - ChannelLogLevel.INFO, "Entering {0} state with picker: {1}", newState, newPicker); - channelStateManager.gotoState(newState); - } - } - } - syncContext.execute(new UpdateBalancingState()); + if (LbHelperImpl.this != lbHelper || panicMode) { + return; + } + updateSubchannelPicker(newPicker); + // It's not appropriate to report SHUTDOWN state from lb. + // Ignore the case of newState == SHUTDOWN for now. + if (newState != SHUTDOWN) { + channelLogger.log( + ChannelLogLevel.INFO, "Entering {0} state with picker: {1}", newState, newPicker); + channelStateManager.gotoState(newState); + } } @Override @@ -1433,84 +1393,28 @@ public ManagedChannel createOobChannel(EquivalentAddressGroup addressGroup, Stri @Override public ManagedChannel createOobChannel(List addressGroup, String authority) { - // TODO(ejona): can we be even stricter? Like terminating? 
- checkState(!terminated, "Channel is terminated"); - long oobChannelCreationTime = timeProvider.currentTimeNanos(); - InternalLogId oobLogId = InternalLogId.allocate("OobChannel", /*details=*/ null); - InternalLogId subchannelLogId = - InternalLogId.allocate("Subchannel-OOB", /*details=*/ authority); - ChannelTracer oobChannelTracer = - new ChannelTracer( - oobLogId, maxTraceEvents, oobChannelCreationTime, - "OobChannel for " + addressGroup); - final OobChannel oobChannel = new OobChannel( - authority, balancerRpcExecutorPool, oobTransportFactory.getScheduledExecutorService(), - syncContext, callTracerFactory.create(), oobChannelTracer, channelz, timeProvider); - channelTracer.reportEvent(new ChannelTrace.Event.Builder() - .setDescription("Child OobChannel created") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(oobChannelCreationTime) - .setChannelRef(oobChannel) - .build()); - ChannelTracer subchannelTracer = - new ChannelTracer(subchannelLogId, maxTraceEvents, oobChannelCreationTime, - "Subchannel for " + addressGroup); - ChannelLogger subchannelLogger = new ChannelLoggerImpl(subchannelTracer, timeProvider); - final class ManagedOobChannelCallback extends InternalSubchannel.Callback { - @Override - void onTerminated(InternalSubchannel is) { - oobChannels.remove(oobChannel); - channelz.removeSubchannel(is); - oobChannel.handleSubchannelTerminated(); - maybeTerminateChannel(); - } - - @Override - void onStateChange(InternalSubchannel is, ConnectivityStateInfo newState) { - // TODO(chengyuanzhang): change to let LB policies explicitly manage OOB channel's - // state and refresh name resolution if necessary. 
- handleInternalSubchannelState(newState); - oobChannel.handleSubchannelStateChange(newState); - } - } - - final InternalSubchannel internalSubchannel = new InternalSubchannel( - addressGroup, - authority, userAgent, backoffPolicyProvider, oobTransportFactory, - oobTransportFactory.getScheduledExecutorService(), stopwatchSupplier, syncContext, - // All callback methods are run from syncContext - new ManagedOobChannelCallback(), - channelz, - callTracerFactory.create(), - subchannelTracer, - subchannelLogId, - subchannelLogger, - transportFilters); - oobChannelTracer.reportEvent(new ChannelTrace.Event.Builder() - .setDescription("Child Subchannel created") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(oobChannelCreationTime) - .setSubchannelRef(internalSubchannel) - .build()); - channelz.addSubchannel(oobChannel); - channelz.addSubchannel(internalSubchannel); - oobChannel.setSubchannel(internalSubchannel); - final class AddOobChannel implements Runnable { - @Override - public void run() { - if (terminating) { - oobChannel.shutdown(); - } - if (!terminated) { - // If channel has not terminated, it will track the subchannel and block termination - // for it. - oobChannels.add(oobChannel); - } - } - } - - syncContext.execute(new AddOobChannel()); - return oobChannel; + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + OobNameResolverProvider resolverProvider = + new OobNameResolverProvider(authority, addressGroup, syncContext); + nameResolverRegistry.register(resolverProvider); + // We could use a hard-coded target, as the name resolver won't actually use this string. + // However, that would make debugging less clear, as we use the target to identify the + // channel. + String target; + try { + target = new URI("oob", "", "/" + authority, null, null).toString(); + } catch (URISyntaxException ex) { + // Any special characters in the path will be percent encoded. So this should be impossible. 
+ throw new AssertionError(ex); + } + ManagedChannel delegate = createResolvingOobChannelBuilder( + target, new DefaultChannelCreds(), nameResolverRegistry) + // TODO(zdapeng): executors should not outlive the parent channel. + .executor(balancerRpcExecutorHolder.getExecutor()) + .idleTimeout(Integer.MAX_VALUE, TimeUnit.SECONDS) + .disableRetry() + .build(); + return new OobChannel(delegate, resolverProvider); } @Deprecated @@ -1522,11 +1426,17 @@ public ManagedChannelBuilder createResolvingOobChannelBuilder(String target) .overrideAuthority(getAuthority()); } - // TODO(creamsoup) prevent main channel to shutdown if oob channel is not terminated - // TODO(zdapeng) register the channel as a subchannel of the parent channel in channelz. @Override public ManagedChannelBuilder createResolvingOobChannelBuilder( final String target, final ChannelCredentials channelCreds) { + return createResolvingOobChannelBuilder(target, channelCreds, nameResolverRegistry); + } + + // TODO(creamsoup) prevent main channel to shutdown if oob channel is not terminated + // TODO(zdapeng) register the channel as a subchannel of the parent channel in channelz. 
+ private ManagedChannelBuilder createResolvingOobChannelBuilder( + final String target, final ChannelCredentials channelCreds, + NameResolverRegistry nameResolverRegistry) { checkNotNull(channelCreds, "channelCreds"); final class ResolvingOobChannelBuilder @@ -1574,7 +1484,6 @@ protected ManagedChannelBuilder delegate() { checkState(!terminated, "Channel is terminated"); - @SuppressWarnings("deprecation") ResolvingOobChannelBuilder builder = new ResolvingOobChannelBuilder(); return builder @@ -1660,6 +1569,19 @@ public ChannelCredentials withoutBearerTokens() { } } + static final class OobChannel extends ForwardingManagedChannel { + private final OobNameResolverProvider resolverProvider; + + public OobChannel(ManagedChannel delegate, OobNameResolverProvider resolverProvider) { + super(delegate); + this.resolverProvider = checkNotNull(resolverProvider, "resolverProvider"); + } + + public void updateAddresses(List eags) { + resolverProvider.updateAddresses(eags); + } + } + final class NameResolverListener extends NameResolver.Listener2 { final LbHelperImpl helper; final NameResolver resolver; @@ -1671,18 +1593,7 @@ final class NameResolverListener extends NameResolver.Listener2 { @Override public void onResult(final ResolutionResult resolutionResult) { - final class NamesResolved implements Runnable { - - @Override - public void run() { - Status status = onResult2(resolutionResult); - ResolutionResultListener resolutionResultListener = resolutionResult.getAttributes() - .get(RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY); - resolutionResultListener.resolutionAttempted(status); - } - } - - syncContext.execute(new NamesResolved()); + syncContext.execute(() -> onResult2(resolutionResult)); } @SuppressWarnings("ReferenceEquality") @@ -1693,7 +1604,13 @@ public Status onResult2(final ResolutionResult resolutionResult) { return Status.OK; } - List servers = resolutionResult.getAddresses(); + StatusOr> serversOrError = + resolutionResult.getAddressesOrError(); + if 
(!serversOrError.hasValue()) { + handleErrorInSyncContext(serversOrError.getStatus()); + return serversOrError.getStatus(); + } + List servers = serversOrError.getValue(); channelLogger.log( ChannelLogLevel.DEBUG, "Resolved address: {0}, config={1}", @@ -1701,10 +1618,10 @@ public Status onResult2(final ResolutionResult resolutionResult) { resolutionResult.getAttributes()); if (lastResolutionState != ResolutionState.SUCCESS) { - channelLogger.log(ChannelLogLevel.INFO, "Address resolved: {0}", servers); + channelLogger.log(ChannelLogLevel.INFO, "Address resolved: {0}", + servers); lastResolutionState = ResolutionState.SUCCESS; } - ConfigOrError configOrError = resolutionResult.getServiceConfig(); InternalConfigSelector resolvedConfigSelector = resolutionResult.getAttributes().get(InternalConfigSelector.KEY); @@ -1780,7 +1697,7 @@ public Status onResult2(final ResolutionResult resolutionResult) { } try { - // TODO(creamsoup): when `servers` is empty and lastResolutionStateCopy == SUCCESS + // TODO(creamsoup): when `serversOrError` is empty and lastResolutionStateCopy == SUCCESS // and lbNeedAddress, it shouldn't call the handleServiceConfigUpdate. 
But, // lbNeedAddress is not deterministic serviceConfigUpdated = true; @@ -1806,12 +1723,13 @@ public Status onResult2(final ResolutionResult resolutionResult) { } Attributes attributes = attrBuilder.build(); - return helper.lb.tryAcceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers) - .setAttributes(attributes) - .setLoadBalancingPolicyConfig(effectiveServiceConfig.getLoadBalancingConfig()) - .build()); + ResolvedAddresses.Builder resolvedAddresses = ResolvedAddresses.newBuilder() + .setAddresses(serversOrError.getValue()) + .setAttributes(attributes) + .setLoadBalancingPolicyConfig(effectiveServiceConfig.getLoadBalancingConfig()); + Status addressAcceptanceStatus = helper.lb.acceptResolvedAddresses( + resolvedAddresses.build()); + return addressAcceptanceStatus; } return Status.OK; } @@ -1907,7 +1825,7 @@ void onNotInUse(InternalSubchannel is) { } final InternalSubchannel internalSubchannel = new InternalSubchannel( - args.getAddresses(), + args, authority(), userAgent, backoffPolicyProvider, @@ -1921,7 +1839,8 @@ void onNotInUse(InternalSubchannel is) { subchannelTracer, subchannelLogId, subchannelLogger, - transportFilters); + transportFilters, target, + lbHelper.getMetricRecorder()); channelTracer.reportEvent(new ChannelTrace.Event.Builder() .setDescription("Child Subchannel started") @@ -1993,6 +1912,9 @@ public void run() { public void requestConnection() { syncContext.throwIfNotInThisSynchronizationContext(); checkState(started, "not started"); + if (shutdown) { + return; + } subchannel.obtainActiveTransport(); } @@ -2075,7 +1997,7 @@ public String toString() { */ private final class DelayedTransportListener implements ManagedClientTransport.Listener { @Override - public void transportShutdown(Status s) { + public void transportShutdown(Status s, DisconnectError e) { checkState(shutdown.get(), "Channel must have been shut down"); } diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java 
b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index 7da9125087e..128c929ec0e 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.internal.UriWrapper.wrap; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; @@ -37,6 +38,7 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.InternalChannelz; import io.grpc.InternalConfiguratorRegistry; +import io.grpc.InternalFeatureFlags; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.MethodDescriptor; @@ -45,6 +47,8 @@ import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; import io.grpc.ProxyDetector; +import io.grpc.StatusOr; +import io.grpc.Uri; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.SocketAddress; @@ -54,6 +58,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.IdentityHashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -158,6 +163,8 @@ public static ManagedChannelBuilder forTarget(String target) { final ChannelCredentials channelCredentials; @Nullable final CallCredentials callCredentials; + @Nullable + IdentityHashMap, Object> nameResolverCustomArgs; @Nullable private final SocketAddress directServerAddress; @@ -612,6 +619,24 @@ private static List checkListEntryTypes(List list) { return Collections.unmodifiableList(parsedList); } + @Override + public ManagedChannelImplBuilder setNameResolverArg(NameResolver.Args.Key key, X value) { + if (nameResolverCustomArgs == null) { + nameResolverCustomArgs = new IdentityHashMap<>(); + } + nameResolverCustomArgs.put(checkNotNull(key, 
"key"), checkNotNull(value, "value")); + return this; + } + + @SuppressWarnings("unchecked") // This cast is safe because of setNameResolverArg()'s signature. + void copyAllNameResolverCustomArgsTo(NameResolver.Args.Builder dest) { + if (nameResolverCustomArgs != null) { + for (Map.Entry, Object> entry : nameResolverCustomArgs.entrySet()) { + dest.setArg((NameResolver.Args.Key) entry.getKey(), entry.getValue()); + } + } + } + @Override public ManagedChannelImplBuilder disableServiceConfigLookUp() { this.lookUpServiceConfig = false; @@ -696,8 +721,11 @@ protected ManagedChannelImplBuilder addMetricSink(MetricSink metricSink) { public ManagedChannel build() { ClientTransportFactory clientTransportFactory = clientTransportFactoryBuilder.buildClientTransportFactory(); - ResolvedNameResolver resolvedResolver = getNameResolverProvider( - target, nameResolverRegistry, clientTransportFactory.getSupportedSocketAddressTypes()); + ResolvedNameResolver resolvedResolver = + InternalFeatureFlags.getRfc3986UrisEnabled() + ? 
getNameResolverProviderRfc3986(target, nameResolverRegistry) + : getNameResolverProvider(target, nameResolverRegistry); + resolvedResolver.checkAddressTypes(clientTransportFactory.getSupportedSocketAddressTypes()); return new ManagedChannelOrphanWrapper(new ManagedChannelImpl( this, clientTransportFactory, @@ -715,18 +743,16 @@ public ManagedChannel build() { // TODO(zdapeng): FIX IT @VisibleForTesting List getEffectiveInterceptors(String computedTarget) { - List effectiveInterceptors = new ArrayList<>(this.interceptors); - for (int i = 0; i < effectiveInterceptors.size(); i++) { - if (!(effectiveInterceptors.get(i) instanceof InterceptorFactoryWrapper)) { - continue; - } - InterceptorFactory factory = - ((InterceptorFactoryWrapper) effectiveInterceptors.get(i)).factory; - ClientInterceptor interceptor = factory.newInterceptor(computedTarget); - if (interceptor == null) { - throw new NullPointerException("Factory returned null interceptor: " + factory); + List effectiveInterceptors = new ArrayList<>(this.interceptors.size()); + for (ClientInterceptor interceptor : this.interceptors) { + if (interceptor instanceof InterceptorFactoryWrapper) { + InterceptorFactory factory = ((InterceptorFactoryWrapper) interceptor).factory; + interceptor = factory.newInterceptor(computedTarget); + if (interceptor == null) { + throw new NullPointerException("Factory returned null interceptor: " + factory); + } } - effectiveInterceptors.set(i, interceptor); + effectiveInterceptors.add(interceptor); } boolean disableImplicitCensus = InternalConfiguratorRegistry.wasSetConfiguratorsCalled(); @@ -739,7 +765,7 @@ List getEffectiveInterceptors(String computedTarget) { if (GET_CLIENT_INTERCEPTOR_METHOD != null) { try { statsInterceptor = - (ClientInterceptor) GET_CLIENT_INTERCEPTOR_METHOD + (ClientInterceptor) GET_CLIENT_INTERCEPTOR_METHOD .invoke( null, recordStartedRpcs, @@ -794,19 +820,32 @@ int getDefaultPort() { @VisibleForTesting static class ResolvedNameResolver { - public final URI 
targetUri; + public final UriWrapper targetUri; public final NameResolverProvider provider; - public ResolvedNameResolver(URI targetUri, NameResolverProvider provider) { + public ResolvedNameResolver(UriWrapper targetUri, NameResolverProvider provider) { this.targetUri = checkNotNull(targetUri, "targetUri"); this.provider = checkNotNull(provider, "provider"); } + + void checkAddressTypes( + Collection> channelTransportSocketAddressTypes) { + if (channelTransportSocketAddressTypes != null) { + Collection> nameResolverSocketAddressTypes = + provider.getProducedSocketAddressTypes(); + if (!channelTransportSocketAddressTypes.containsAll(nameResolverSocketAddressTypes)) { + throw new IllegalArgumentException( + String.format( + "Address types of NameResolver '%s' for '%s' not supported by transport", + provider.getDefaultScheme(), targetUri)); + } + } + } } @VisibleForTesting static ResolvedNameResolver getNameResolverProvider( - String target, NameResolverRegistry nameResolverRegistry, - Collection> channelTransportSocketAddressTypes) { + String target, NameResolverRegistry nameResolverRegistry) { // Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending // "dns:///". NameResolverProvider provider = null; @@ -842,17 +881,49 @@ static ResolvedNameResolver getNameResolverProvider( target, uriSyntaxErrors.length() > 0 ? 
" (" + uriSyntaxErrors + ")" : "")); } - if (channelTransportSocketAddressTypes != null) { - Collection> nameResolverSocketAddressTypes - = provider.getProducedSocketAddressTypes(); - if (!channelTransportSocketAddressTypes.containsAll(nameResolverSocketAddressTypes)) { - throw new IllegalArgumentException(String.format( - "Address types of NameResolver '%s' for '%s' not supported by transport", - targetUri.getScheme(), target)); - } + return new ResolvedNameResolver(wrap(targetUri), provider); + } + + @VisibleForTesting + static ResolvedNameResolver getNameResolverProviderRfc3986( + String target, NameResolverRegistry nameResolverRegistry) { + // Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending + // "dns:///". + NameResolverProvider provider = null; + Uri targetUri = null; + StringBuilder uriSyntaxErrors = new StringBuilder(); + try { + targetUri = Uri.parse(target); + } catch (URISyntaxException e) { + // Can happen with ip addresses like "[::1]:1234" or 127.0.0.1:1234. + uriSyntaxErrors.append(e.getMessage()); + } + if (targetUri != null) { + // For "localhost:8080" this would likely cause provider to be null, because "localhost" is + // parsed as the scheme. Will hit the next case and try "dns:///localhost:8080". + provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); + } + + if (provider == null && !URI_PATTERN.matcher(target).matches()) { + // It doesn't look like a URI target. Maybe it's an authority string. Try with the default + // scheme from the registry. + targetUri = + Uri.newBuilder() + .setScheme(nameResolverRegistry.getDefaultScheme()) + .setHost("") + .setPath("/" + target) + .build(); + provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); + } + + if (provider == null) { + throw new IllegalArgumentException( + String.format( + "Could not find a NameResolverProvider for %s%s", + target, uriSyntaxErrors.length() > 0 ? 
" (" + uriSyntaxErrors + ")" : "")); } - return new ResolvedNameResolver(targetUri, provider); + return new ResolvedNameResolver(wrap(targetUri), provider); } private static class DirectAddressNameResolverProvider extends NameResolverProvider { @@ -877,9 +948,11 @@ public String getServiceAuthority() { @Override public void start(Listener2 listener) { - listener.onResult( + listener.onResult2( ResolutionResult.newBuilder() - .setAddresses(Collections.singletonList(new EquivalentAddressGroup(address))) + .setAddressesOrError( + StatusOr.fromValue( + Collections.singletonList(new EquivalentAddressGroup(address)))) .setAttributes(Attributes.EMPTY) .build()); } diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java b/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java index eac9b64d9db..790d5bd297f 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java @@ -63,12 +63,20 @@ final class ManagedChannelOrphanWrapper extends ForwardingManagedChannel { @Override public ManagedChannel shutdown() { phantom.clearSafely(); + // This dummy check prevents the JIT from collecting 'this' too early + if (this.getClass() == null) { + throw new AssertionError(); + } return super.shutdown(); } @Override public ManagedChannel shutdownNow() { phantom.clearSafely(); + // This dummy check prevents the JIT from collecting 'this' too early + if (this.getClass() == null) { + throw new AssertionError(); + } return super.shutdownNow(); } @@ -151,8 +159,9 @@ static int cleanQueue(ReferenceQueue refqueue) { int orphanedChannels = 0; while ((ref = (ManagedChannelReference) refqueue.poll()) != null) { RuntimeException maybeAllocationSite = ref.allocationSite.get(); + boolean wasShutdown = ref.shutdown.get(); ref.clearInternal(); // technically the reference is gone already. 
- if (!ref.shutdown.get()) { + if (!wasShutdown) { orphanedChannels++; Level level = Level.SEVERE; if (logger.isLoggable(level)) { diff --git a/core/src/main/java/io/grpc/internal/ManagedClientTransport.java b/core/src/main/java/io/grpc/internal/ManagedClientTransport.java index 5f8fe52ef6b..8350a005409 100644 --- a/core/src/main/java/io/grpc/internal/ManagedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ManagedClientTransport.java @@ -16,9 +16,9 @@ package io.grpc.internal; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.Status; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; @@ -77,8 +77,9 @@ interface Listener { *

This is called exactly once, and must be called prior to {@link #transportTerminated}. * * @param s the reason for the shutdown. + * @param e the disconnect error. */ - void transportShutdown(Status s); + void transportShutdown(Status s, DisconnectError e); /** * The transport completed shutting down. All resources have been released. All streams have diff --git a/core/src/main/java/io/grpc/internal/MessageDeframer.java b/core/src/main/java/io/grpc/internal/MessageDeframer.java index c8b250c2143..13a01efec0a 100644 --- a/core/src/main/java/io/grpc/internal/MessageDeframer.java +++ b/core/src/main/java/io/grpc/internal/MessageDeframer.java @@ -406,7 +406,8 @@ private void processBody() { // There is no reliable way to get the uncompressed size per message when it's compressed, // because the uncompressed bytes are provided through an InputStream whose total size is // unknown until all bytes are read, and we don't know when it happens. - statsTraceCtx.inboundMessageRead(currentMessageSeqNo, inboundBodyWireSize, -1); + statsTraceCtx.inboundMessageRead(currentMessageSeqNo, inboundBodyWireSize, + (compressedFlag || fullStreamDecompressor != null) ? -1 : inboundBodyWireSize); inboundBodyWireSize = 0; InputStream stream = compressedFlag ? getCompressedBody() : getUncompressedBody(); nextFrame.touch(); diff --git a/core/src/main/java/io/grpc/internal/MessageFramer.java b/core/src/main/java/io/grpc/internal/MessageFramer.java index 5e75fa2e6fe..8b5ccb864a4 100644 --- a/core/src/main/java/io/grpc/internal/MessageFramer.java +++ b/core/src/main/java/io/grpc/internal/MessageFramer.java @@ -75,6 +75,10 @@ void deliverFrame( // effectively final. Can only be set once. private int maxOutboundMessageSize = NO_MAX_OUTBOUND_MESSAGE_SIZE; private WritableBuffer buffer; + /** + * if > 0 - the number of bytes to allocate for the current known-length message. 
+ */ + private int knownLengthPendingAllocation; private Compressor compressor = Codec.Identity.NONE; private boolean messageCompression = true; private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter(); @@ -222,9 +226,7 @@ private int writeKnownLengthUncompressed(InputStream message, int messageLength) headerScratch.put(UNCOMPRESSED).putInt(messageLength); // Allocate the initial buffer chunk based on frame header + payload length. // Note that the allocator may allocate a buffer larger or smaller than this length - if (buffer == null) { - buffer = bufferAllocator.allocate(headerScratch.position() + messageLength); - } + knownLengthPendingAllocation = HEADER_LENGTH + messageLength; writeRaw(headerScratch.array(), 0, headerScratch.position()); return writeToOutputStream(message, outputStreamAdapter); } @@ -288,8 +290,9 @@ private void writeRaw(byte[] b, int off, int len) { commitToSink(false, false); } if (buffer == null) { - // Request a buffer allocation using the message length as a hint. - buffer = bufferAllocator.allocate(len); + checkState(knownLengthPendingAllocation > 0, "knownLengthPendingAllocation reached 0"); + buffer = bufferAllocator.allocate(knownLengthPendingAllocation); + knownLengthPendingAllocation -= min(knownLengthPendingAllocation, buffer.writableBytes()); } int toWrite = min(len, buffer.writableBytes()); buffer.write(b, off, toWrite); @@ -388,6 +391,8 @@ public void write(byte[] b, int off, int len) { * {@link OutputStream}. */ private final class BufferChainOutputStream extends OutputStream { + private static final int FIRST_BUFFER_SIZE = 4096; + private final List bufferList = new ArrayList<>(); private WritableBuffer current; @@ -397,7 +402,7 @@ private final class BufferChainOutputStream extends OutputStream { * {@link #write(byte[], int, int)}. 
*/ @Override - public void write(int b) throws IOException { + public void write(int b) { if (current != null && current.writableBytes() > 0) { current.write((byte)b); return; @@ -410,7 +415,7 @@ public void write(int b) throws IOException { public void write(byte[] b, int off, int len) { if (current == null) { // Request len bytes initially from the allocator, it may give us more. - current = bufferAllocator.allocate(len); + current = bufferAllocator.allocate(Math.max(FIRST_BUFFER_SIZE, len)); bufferList.add(current); } while (len > 0) { diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index 12cab15053f..166f97b78f5 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.CallCredentials.MetadataApplier; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; @@ -28,7 +29,6 @@ import io.grpc.MethodDescriptor; import io.grpc.Status; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; final class MetadataApplierImpl extends MetadataApplier { private final ClientTransport transport; @@ -120,7 +120,7 @@ ClientStream returnStream() { synchronized (lock) { if (returnedStream == null) { // apply() has not been called, needs to buffer the requests. 
- delayedStream = new DelayedStream(); + delayedStream = new DelayedStream("call_credentials"); return returnedStream = delayedStream; } else { return returnedStream; diff --git a/core/src/main/java/io/grpc/internal/MetricRecorderImpl.java b/core/src/main/java/io/grpc/internal/MetricRecorderImpl.java index 452b1c5df07..6a12a38d677 100644 --- a/core/src/main/java/io/grpc/internal/MetricRecorderImpl.java +++ b/core/src/main/java/io/grpc/internal/MetricRecorderImpl.java @@ -20,12 +20,14 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import io.grpc.CallbackMetricInstrument; import io.grpc.DoubleCounterMetricInstrument; import io.grpc.DoubleHistogramMetricInstrument; import io.grpc.LongCounterMetricInstrument; import io.grpc.LongGaugeMetricInstrument; import io.grpc.LongHistogramMetricInstrument; +import io.grpc.LongUpDownCounterMetricInstrument; import io.grpc.MetricInstrument; import io.grpc.MetricInstrumentRegistry; import io.grpc.MetricRecorder; @@ -48,7 +50,7 @@ final class MetricRecorderImpl implements MetricRecorder { @VisibleForTesting MetricRecorderImpl(List metricSinks, MetricInstrumentRegistry registry) { - this.metricSinks = metricSinks; + this.metricSinks = ImmutableList.copyOf(metricSinks); this.registry = registry; } @@ -82,7 +84,7 @@ public void addDoubleCounter(DoubleCounterMetricInstrument metricInstrument, dou * Records a long counter value. * * @param metricInstrument the {@link LongCounterMetricInstrument} to record. - * @param value the value to record. + * @param value the value to record. Must be non-negative. * @param requiredLabelValues the required label values for the metric. * @param optionalLabelValues the optional label values for the metric. */ @@ -103,6 +105,32 @@ public void addLongCounter(LongCounterMetricInstrument metricInstrument, long va } } + /** + * Adds a long up down counter value. 
+ * + * @param metricInstrument the {@link io.grpc.LongUpDownCounterMetricInstrument} to record. + * @param value the value to record. May be positive, negative or zero. + * @param requiredLabelValues the required label values for the metric. + * @param optionalLabelValues the optional label values for the metric. + */ + @Override + public void addLongUpDownCounter(LongUpDownCounterMetricInstrument metricInstrument, long value, + List requiredLabelValues, + List optionalLabelValues) { + MetricRecorder.super.addLongUpDownCounter(metricInstrument, value, requiredLabelValues, + optionalLabelValues); + for (MetricSink sink : metricSinks) { + int measuresSize = sink.getMeasuresSize(); + if (measuresSize <= metricInstrument.getIndex()) { + // Measures may need updating in two cases: + // 1. When the sink is initially created with an empty list of measures. + // 2. When new metric instruments are registered, requiring the sink to accommodate them. + sink.updateMeasures(registry.getMetricInstruments()); + } + sink.addLongUpDownCounter(metricInstrument, value, requiredLabelValues, optionalLabelValues); + } + } + /** * Records a double histogram value. 
* diff --git a/core/src/main/java/io/grpc/internal/MigratingThreadDeframer.java b/core/src/main/java/io/grpc/internal/MigratingThreadDeframer.java index c3342556c9f..e4f499ab483 100644 --- a/core/src/main/java/io/grpc/internal/MigratingThreadDeframer.java +++ b/core/src/main/java/io/grpc/internal/MigratingThreadDeframer.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Decompressor; import io.perfmark.Link; import io.perfmark.PerfMark; @@ -26,7 +27,6 @@ import java.io.InputStream; import java.util.ArrayDeque; import java.util.Queue; -import javax.annotation.concurrent.GuardedBy; /** * A deframer that moves decoding between the transport and app threads based on which is more diff --git a/core/src/main/java/io/grpc/internal/NoopClientStream.java b/core/src/main/java/io/grpc/internal/NoopClientStream.java index d44170f69fa..d77d72a5412 100644 --- a/core/src/main/java/io/grpc/internal/NoopClientStream.java +++ b/core/src/main/java/io/grpc/internal/NoopClientStream.java @@ -45,7 +45,9 @@ public Attributes getAttributes() { public void request(int numMessages) {} @Override - public void writeMessage(InputStream message) {} + public void writeMessage(InputStream message) { + GrpcUtil.closeQuietly(message); + } @Override public void flush() {} diff --git a/core/src/main/java/io/grpc/internal/NoopSslSession.java b/core/src/main/java/io/grpc/internal/NoopSslSession.java new file mode 100644 index 00000000000..9a79d281ad5 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/NoopSslSession.java @@ -0,0 +1,132 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import java.security.Principal; +import java.security.cert.Certificate; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSessionContext; + +/** A no-op ssl session, to facilitate overriding only the required methods in specific + * implementations. + */ +public class NoopSslSession implements SSLSession { + @Override + public byte[] getId() { + return new byte[0]; + } + + @Override + public SSLSessionContext getSessionContext() { + return null; + } + + @Override + @SuppressWarnings("deprecation") + public javax.security.cert.X509Certificate[] getPeerCertificateChain() { + throw new UnsupportedOperationException("This method is deprecated and marked for removal. 
" + + "Use the getPeerCertificates() method instead."); + } + + @Override + public long getCreationTime() { + return 0; + } + + @Override + public long getLastAccessedTime() { + return 0; + } + + @Override + public void invalidate() { + } + + @Override + public boolean isValid() { + return false; + } + + @Override + public void putValue(String s, Object o) { + } + + @Override + public Object getValue(String s) { + return null; + } + + @Override + public void removeValue(String s) { + } + + @Override + public String[] getValueNames() { + return new String[0]; + } + + @Override + public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException { + return new Certificate[0]; + } + + @Override + public Certificate[] getLocalCertificates() { + return new Certificate[0]; + } + + @Override + public Principal getPeerPrincipal() throws SSLPeerUnverifiedException { + return null; + } + + @Override + public Principal getLocalPrincipal() { + return null; + } + + @Override + public String getCipherSuite() { + return null; + } + + @Override + public String getProtocol() { + return null; + } + + @Override + public String getPeerHost() { + return null; + } + + @Override + public int getPeerPort() { + return 0; + } + + @Override + public int getPacketBufferSize() { + return 0; + } + + @Override + public int getApplicationBufferSize() { + return 0; + } +} diff --git a/core/src/main/java/io/grpc/internal/OobChannel.java b/core/src/main/java/io/grpc/internal/OobChannel.java deleted file mode 100644 index 01ef457460f..00000000000 --- a/core/src/main/java/io/grpc/internal/OobChannel.java +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Copyright 2016 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.internal; - -import static com.google.common.base.Preconditions.checkNotNull; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.MoreObjects; -import com.google.common.base.Preconditions; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; -import io.grpc.Attributes; -import io.grpc.CallOptions; -import io.grpc.ClientCall; -import io.grpc.ClientStreamTracer; -import io.grpc.ConnectivityState; -import io.grpc.ConnectivityStateInfo; -import io.grpc.Context; -import io.grpc.EquivalentAddressGroup; -import io.grpc.InternalChannelz; -import io.grpc.InternalChannelz.ChannelStats; -import io.grpc.InternalChannelz.ChannelTrace; -import io.grpc.InternalInstrumented; -import io.grpc.InternalLogId; -import io.grpc.InternalWithLogId; -import io.grpc.LoadBalancer; -import io.grpc.LoadBalancer.PickResult; -import io.grpc.LoadBalancer.PickSubchannelArgs; -import io.grpc.LoadBalancer.Subchannel; -import io.grpc.LoadBalancer.SubchannelPicker; -import io.grpc.ManagedChannel; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.internal.ClientCallImpl.ClientStreamProvider; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import java.util.logging.Level; -import 
java.util.logging.Logger; -import javax.annotation.concurrent.ThreadSafe; - -/** - * A ManagedChannel backed by a single {@link InternalSubchannel} and used for {@link LoadBalancer} - * to its own RPC needs. - */ -@ThreadSafe -final class OobChannel extends ManagedChannel implements InternalInstrumented { - private static final Logger log = Logger.getLogger(OobChannel.class.getName()); - - private InternalSubchannel subchannel; - private AbstractSubchannel subchannelImpl; - private SubchannelPicker subchannelPicker; - - private final InternalLogId logId; - private final String authority; - private final DelayedClientTransport delayedTransport; - private final InternalChannelz channelz; - private final ObjectPool executorPool; - private final Executor executor; - private final ScheduledExecutorService deadlineCancellationExecutor; - private final CountDownLatch terminatedLatch = new CountDownLatch(1); - private volatile boolean shutdown; - private final CallTracer channelCallsTracer; - private final ChannelTracer channelTracer; - private final TimeProvider timeProvider; - - private final ClientStreamProvider transportProvider = new ClientStreamProvider() { - @Override - public ClientStream newStream(MethodDescriptor method, - CallOptions callOptions, Metadata headers, Context context) { - ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( - callOptions, headers, 0, /* isTransparentRetry= */ false); - Context origContext = context.attach(); - // delayed transport's newStream() always acquires a lock, but concurrent performance doesn't - // matter here because OOB communication should be sparse, and it's not on application RPC's - // critical path. 
- try { - return delayedTransport.newStream(method, headers, callOptions, tracers); - } finally { - context.detach(origContext); - } - } - }; - - OobChannel( - String authority, ObjectPool executorPool, - ScheduledExecutorService deadlineCancellationExecutor, SynchronizationContext syncContext, - CallTracer callsTracer, ChannelTracer channelTracer, InternalChannelz channelz, - TimeProvider timeProvider) { - this.authority = checkNotNull(authority, "authority"); - this.logId = InternalLogId.allocate(getClass(), authority); - this.executorPool = checkNotNull(executorPool, "executorPool"); - this.executor = checkNotNull(executorPool.getObject(), "executor"); - this.deadlineCancellationExecutor = checkNotNull( - deadlineCancellationExecutor, "deadlineCancellationExecutor"); - this.delayedTransport = new DelayedClientTransport(executor, syncContext); - this.channelz = Preconditions.checkNotNull(channelz); - this.delayedTransport.start(new ManagedClientTransport.Listener() { - @Override - public void transportShutdown(Status s) { - // Don't care - } - - @Override - public void transportTerminated() { - subchannelImpl.shutdown(); - } - - @Override - public void transportReady() { - // Don't care - } - - @Override - public Attributes filterTransport(Attributes attributes) { - return attributes; - } - - @Override - public void transportInUse(boolean inUse) { - // Don't care - } - }); - this.channelCallsTracer = callsTracer; - this.channelTracer = checkNotNull(channelTracer, "channelTracer"); - this.timeProvider = checkNotNull(timeProvider, "timeProvider"); - } - - // Must be called only once, right after the OobChannel is created. 
- void setSubchannel(final InternalSubchannel subchannel) { - log.log(Level.FINE, "[{0}] Created with [{1}]", new Object[] {this, subchannel}); - this.subchannel = subchannel; - subchannelImpl = new AbstractSubchannel() { - @Override - public void shutdown() { - subchannel.shutdown(Status.UNAVAILABLE.withDescription("OobChannel is shutdown")); - } - - @Override - InternalInstrumented getInstrumentedInternalSubchannel() { - return subchannel; - } - - @Override - public void requestConnection() { - subchannel.obtainActiveTransport(); - } - - @Override - public List getAllAddresses() { - return subchannel.getAddressGroups(); - } - - @Override - public Attributes getAttributes() { - return Attributes.EMPTY; - } - - @Override - public Object getInternalSubchannel() { - return subchannel; - } - }; - - final class OobSubchannelPicker extends SubchannelPicker { - final PickResult result = PickResult.withSubchannel(subchannelImpl); - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return result; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(OobSubchannelPicker.class) - .add("result", result) - .toString(); - } - } - - subchannelPicker = new OobSubchannelPicker(); - delayedTransport.reprocess(subchannelPicker); - } - - void updateAddresses(List eag) { - subchannel.updateAddresses(eag); - } - - @Override - public ClientCall newCall( - MethodDescriptor methodDescriptor, CallOptions callOptions) { - return new ClientCallImpl<>(methodDescriptor, - callOptions.getExecutor() == null ? 
executor : callOptions.getExecutor(), - callOptions, transportProvider, deadlineCancellationExecutor, channelCallsTracer, null); - } - - @Override - public String authority() { - return authority; - } - - @Override - public boolean isTerminated() { - return terminatedLatch.getCount() == 0; - } - - @Override - public boolean awaitTermination(long time, TimeUnit unit) throws InterruptedException { - return terminatedLatch.await(time, unit); - } - - @Override - public ConnectivityState getState(boolean requestConnectionIgnored) { - if (subchannel == null) { - return ConnectivityState.IDLE; - } - return subchannel.getState(); - } - - @Override - public ManagedChannel shutdown() { - shutdown = true; - delayedTransport.shutdown(Status.UNAVAILABLE.withDescription("OobChannel.shutdown() called")); - return this; - } - - @Override - public boolean isShutdown() { - return shutdown; - } - - @Override - public ManagedChannel shutdownNow() { - shutdown = true; - delayedTransport.shutdownNow( - Status.UNAVAILABLE.withDescription("OobChannel.shutdownNow() called")); - return this; - } - - void handleSubchannelStateChange(final ConnectivityStateInfo newState) { - channelTracer.reportEvent( - new ChannelTrace.Event.Builder() - .setDescription("Entering " + newState.getState() + " state") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(timeProvider.currentTimeNanos()) - .build()); - switch (newState.getState()) { - case READY: - case IDLE: - delayedTransport.reprocess(subchannelPicker); - break; - case TRANSIENT_FAILURE: - final class OobErrorPicker extends SubchannelPicker { - final PickResult errorResult = PickResult.withError(newState.getStatus()); - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return errorResult; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(OobErrorPicker.class) - .add("errorResult", errorResult) - .toString(); - } - } - - delayedTransport.reprocess(new 
OobErrorPicker()); - break; - default: - // Do nothing - } - } - - // must be run from channel executor - void handleSubchannelTerminated() { - channelz.removeSubchannel(this); - // When delayedTransport is terminated, it shuts down subchannel. Therefore, at this point - // both delayedTransport and subchannel have terminated. - executorPool.returnObject(executor); - terminatedLatch.countDown(); - } - - @VisibleForTesting - Subchannel getSubchannel() { - return subchannelImpl; - } - - InternalSubchannel getInternalSubchannel() { - return subchannel; - } - - @Override - public ListenableFuture getStats() { - final SettableFuture ret = SettableFuture.create(); - final ChannelStats.Builder builder = new ChannelStats.Builder(); - channelCallsTracer.updateBuilder(builder); - channelTracer.updateBuilder(builder); - builder - .setTarget(authority) - .setState(subchannel.getState()) - .setSubchannels(Collections.singletonList(subchannel)); - ret.set(builder.build()); - return ret; - } - - @Override - public InternalLogId getLogId() { - return logId; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("logId", logId.getId()) - .add("authority", authority) - .toString(); - } - - @Override - public void resetConnectBackoff() { - subchannel.resetConnectBackoff(); - } -} diff --git a/core/src/main/java/io/grpc/internal/OobNameResolverProvider.java b/core/src/main/java/io/grpc/internal/OobNameResolverProvider.java new file mode 100644 index 00000000000..408b92e0c84 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/OobNameResolverProvider.java @@ -0,0 +1,121 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static java.util.Objects.requireNonNull; + +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; +import io.grpc.StatusOr; +import io.grpc.SynchronizationContext; +import java.net.URI; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +/** + * A provider that is passed addresses and relays those addresses to its created resolvers. + */ +final class OobNameResolverProvider extends NameResolverProvider { + private final String authority; + private final SynchronizationContext parentSyncContext; + // Only accessed from parentSyncContext + @SuppressWarnings("JdkObsolete") // LinkedList uses O(n) memory, including after deletions + private final Collection resolvers = new LinkedList<>(); + // Only accessed from parentSyncContext + private List lastEags; + + public OobNameResolverProvider( + String authority, List eags, SynchronizationContext syncContext) { + this.authority = requireNonNull(authority, "authority"); + this.lastEags = requireNonNull(eags, "eags"); + this.parentSyncContext = requireNonNull(syncContext, "syncContext"); + } + + @Override + public String getDefaultScheme() { + return "oob"; + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; // Doesn't matter, as we expect only one provider in the registry + } + + public void updateAddresses(List eags) { + requireNonNull(eags, "eags"); + parentSyncContext.execute(() -> { + this.lastEags 
= eags; + for (OobNameResolver resolver : resolvers) { + resolver.updateAddresses(eags); + } + }); + } + + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return new OobNameResolver(args.getSynchronizationContext()); + } + + final class OobNameResolver extends NameResolver { + private final SynchronizationContext syncContext; + // Null before started, and after shutdown. Only accessed from syncContext + private Listener2 listener; + + public OobNameResolver(SynchronizationContext syncContext) { + this.syncContext = requireNonNull(syncContext, "syncContext"); + } + + @Override + public String getServiceAuthority() { + return authority; + } + + @Override + public void start(Listener2 listener) { + this.listener = requireNonNull(listener, "listener"); + parentSyncContext.execute(() -> { + resolvers.add(this); + updateAddresses(lastEags); + }); + } + + void updateAddresses(List eags) { + parentSyncContext.throwIfNotInThisSynchronizationContext(); + syncContext.execute(() -> { + if (listener == null) { + return; + } + listener.onResult2(ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(lastEags)) + .build()); + }); + } + + @Override + public void shutdown() { + this.listener = null; + parentSyncContext.execute(() -> resolvers.remove(this)); + } + } +} diff --git a/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java index bfa462e16e1..f8f5c94f5ba 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java @@ -24,16 +24,19 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.MoreObjects; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import com.google.errorprone.annotations.CheckReturnValue; import 
io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; import io.grpc.SynchronizationContext.ScheduledHandle; +import java.net.Inet4Address; +import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.ArrayList; import java.util.Collections; @@ -58,30 +61,41 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer { private static final Logger log = Logger.getLogger(PickFirstLeafLoadBalancer.class.getName()); @VisibleForTesting static final int CONNECTION_DELAY_INTERVAL_MS = 250; + private final boolean enableHappyEyeballs = !isSerializingRetries() + && PickFirstLoadBalancerProvider.isEnabledHappyEyeballs(); + static boolean weightedShuffling = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true); private final Helper helper; private final Map subchannels = new HashMap<>(); - private final Index addressIndex = new Index(ImmutableList.of()); + private final Index addressIndex = new Index(ImmutableList.of(), this.enableHappyEyeballs); private int numTf = 0; private boolean firstPass = true; @Nullable - private ScheduledHandle scheduleConnectionTask; + private ScheduledHandle scheduleConnectionTask = null; private ConnectivityState rawConnectivityState = IDLE; private ConnectivityState concludedState = IDLE; - private final boolean enableHappyEyeballs = - PickFirstLoadBalancerProvider.isEnabledHappyEyeballs(); private boolean notAPetiolePolicy = true; // means not under a petiole policy + private final BackoffPolicy.Provider bkoffPolProvider = new ExponentialBackoffPolicy.Provider(); + private BackoffPolicy reconnectPolicy; + @Nullable + private ScheduledHandle reconnectTask = null; + private final boolean serializingRetries = isSerializingRetries(); PickFirstLeafLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); } + static 
boolean isSerializingRetries() { + return GrpcUtil.getFlag("GRPC_SERIALIZE_RETRIES", false); + } + @Override public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (rawConnectivityState == SHUTDOWN) { return Status.FAILED_PRECONDITION.withDescription("Already shut down"); } - // Cache whether or not this is a petiole policy, which is based off of an address attribute + // Check whether this is a petiole policy, which is based off of an address attribute Boolean isPetiolePolicy = resolvedAddresses.getAttributes().get(IS_PETIOLE_POLICY); this.notAPetiolePolicy = isPetiolePolicy == null || !isPetiolePolicy; @@ -118,22 +132,27 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { PickFirstLeafLoadBalancerConfig config = (PickFirstLeafLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); if (config.shuffleAddressList != null && config.shuffleAddressList) { - Collections.shuffle(cleanServers, - config.randomSeed != null ? new Random(config.randomSeed) : new Random()); + cleanServers = shuffle( + cleanServers, config.randomSeed != null ? new Random(config.randomSeed) : new Random()); } } final ImmutableList newImmutableAddressGroups = - ImmutableList.builder().addAll(cleanServers).build(); - - if (rawConnectivityState == READY) { - // If the previous ready subchannel exists in new address list, - // keep this connection and don't create new subchannels + ImmutableList.copyOf(cleanServers); + + if (rawConnectivityState == READY + || (rawConnectivityState == CONNECTING + && (!enableHappyEyeballs || addressIndex.isValid()))) { + // If the previous ready (or connecting) subchannel exists in new address list, + // keep this connection and don't create new subchannels. Happy Eyeballs is excluded when + // connecting, because it allows multiple attempts simultaneously, thus is fine to start at + // the beginning. 
SocketAddress previousAddress = addressIndex.getCurrentAddress(); addressIndex.updateGroups(newImmutableAddressGroups); if (addressIndex.seekTo(previousAddress)) { SubchannelData subchannelData = subchannels.get(previousAddress); subchannelData.getSubchannel().updateAddresses(addressIndex.getCurrentEagAsList()); + shutdownRemovedAddresses(newImmutableAddressGroups); return Status.OK; } // Previous ready subchannel not in the new list of addresses @@ -141,26 +160,14 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { addressIndex.updateGroups(newImmutableAddressGroups); } - // remove old subchannels that were not in new address list - Set oldAddrs = new HashSet<>(subchannels.keySet()); + // No old addresses means first time through, so we will do an explicit move to CONNECTING + // which is what we implicitly started with + boolean noOldAddrs = shutdownRemovedAddresses(newImmutableAddressGroups); - // Flatten the new EAGs addresses - Set newAddrs = new HashSet<>(); - for (EquivalentAddressGroup endpoint : newImmutableAddressGroups) { - newAddrs.addAll(endpoint.getAddresses()); - } - - // Shut them down and remove them - for (SocketAddress oldAddr : oldAddrs) { - if (!newAddrs.contains(oldAddr)) { - subchannels.remove(oldAddr).getSubchannel().shutdown(); - } - } - - if (oldAddrs.size() == 0) { + if (noOldAddrs) { // Make tests happy; they don't properly assume starting in CONNECTING rawConnectivityState = CONNECTING; - updateBalancingState(CONNECTING, new Picker(PickResult.withNoResult())); + updateBalancingState(CONNECTING, new FixedResultPicker(PickResult.withNoResult())); } if (rawConnectivityState == READY) { @@ -177,6 +184,31 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { return Status.OK; } + /** + * Compute the difference between the flattened new addresses and the old addresses that had been + * made into subchannels and then shutdown the matching subchannels. 
+ * @return true if there were no old addresses + */ + private boolean shutdownRemovedAddresses( + ImmutableList newImmutableAddressGroups) { + + Set oldAddrs = new HashSet<>(subchannels.keySet()); + + // Flatten the new EAGs addresses + Set newAddrs = new HashSet<>(); + for (EquivalentAddressGroup endpoint : newImmutableAddressGroups) { + newAddrs.addAll(endpoint.getAddresses()); + } + + // Shut them down and remove them + for (SocketAddress oldAddr : oldAddrs) { + if (!newAddrs.contains(oldAddr)) { + subchannels.remove(oldAddr).getSubchannel().shutdown(); + } + } + return oldAddrs.isEmpty(); + } + private static List deDupAddresses(List groups) { Set seenAddresses = new HashSet<>(); List newGroups = new ArrayList<>(); @@ -196,6 +228,46 @@ private static List deDupAddresses(List shuffle(List eags, Random random) { + if (weightedShuffling) { + List weightedEntries = new ArrayList<>(eags.size()); + for (EquivalentAddressGroup eag : eags) { + weightedEntries.add(new WeightEntry(eag, eagToWeight(eag, random))); + } + Collections.sort(weightedEntries, Collections.reverseOrder() /* descending */); + return Lists.transform(weightedEntries, entry -> entry.eag); + } else { + List eagsCopy = new ArrayList<>(eags); + Collections.shuffle(eagsCopy, random); + return eagsCopy; + } + } + + private static double eagToWeight(EquivalentAddressGroup eag, Random random) { + Long weight = eag.getAttributes().get(InternalEquivalentAddressGroup.ATTR_WEIGHT); + if (weight == null) { + weight = 1L; + } + return Math.pow(random.nextDouble(), 1.0 / weight); + } + + private static final class WeightEntry implements Comparable { + final EquivalentAddressGroup eag; + final double weight; + + public WeightEntry(EquivalentAddressGroup eag, double weight) { + this.eag = eag; + this.weight = weight; + } + + @Override + public int compareTo(WeightEntry entry) { + return Double.compare(this.weight, entry.weight); + } + } + @Override public void handleNameResolutionError(Status error) { if 
(rawConnectivityState == SHUTDOWN) { @@ -208,7 +280,7 @@ public void handleNameResolutionError(Status error) { subchannels.clear(); addressIndex.updateGroups(ImmutableList.of()); rawConnectivityState = TRANSIENT_FAILURE; - updateBalancingState(TRANSIENT_FAILURE, new Picker(PickResult.withError(error))); + updateBalancingState(TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); } void processSubchannelState(SubchannelData subchannelData, ConnectivityStateInfo stateInfo) { @@ -225,9 +297,10 @@ void processSubchannelState(SubchannelData subchannelData, ConnectivityStateInfo return; } - if (newState == IDLE) { + if (newState == IDLE && subchannelData.state == READY) { helper.refreshNameResolution(); } + // If we are transitioning from a TRANSIENT_FAILURE to CONNECTING or IDLE we ignore this state // transition and still keep the LB in TRANSIENT_FAILURE state. This is referred to as "sticky // transient failure". Only a subchannel state change to READY will get the LB out of @@ -260,7 +333,7 @@ void processSubchannelState(SubchannelData subchannelData, ConnectivityStateInfo case CONNECTING: rawConnectivityState = CONNECTING; - updateBalancingState(CONNECTING, new Picker(PickResult.withNoResult())); + updateBalancingState(CONNECTING, new FixedResultPicker(PickResult.withNoResult())); break; case READY: @@ -277,13 +350,22 @@ void processSubchannelState(SubchannelData subchannelData, ConnectivityStateInfo if (addressIndex.increment()) { cancelScheduleTask(); requestConnection(); // is recursive so might hit the end of the addresses + } else { + if (subchannels.size() >= addressIndex.size()) { + scheduleBackoff(); + } else { + // We must have done a seek to the middle of the list lets start over from the + // beginning + addressIndex.reset(); + requestConnection(); + } } } if (isPassComplete()) { rawConnectivityState = TRANSIENT_FAILURE; updateBalancingState(TRANSIENT_FAILURE, - new Picker(PickResult.withError(stateInfo.getStatus()))); + new 
FixedResultPicker(PickResult.withError(stateInfo.getStatus()))); // Refresh Name Resolution, but only when all 3 conditions are met // * We are at the end of addressIndex @@ -304,6 +386,39 @@ void processSubchannelState(SubchannelData subchannelData, ConnectivityStateInfo } } + /** + * Only called after all addresses attempted and failed (TRANSIENT_FAILURE). + */ + private void scheduleBackoff() { + if (!serializingRetries) { + return; + } + + class EndOfCurrentBackoff implements Runnable { + @Override + public void run() { + reconnectTask = null; + addressIndex.reset(); + requestConnection(); + } + } + + // Just allow the previous one to trigger when ready if we're already in backoff + if (reconnectTask != null) { + return; + } + + if (reconnectPolicy == null) { + reconnectPolicy = bkoffPolProvider.get(); + } + long delayNanos = reconnectPolicy.nextBackoffNanos(); + reconnectTask = helper.getSynchronizationContext().schedule( + new EndOfCurrentBackoff(), + delayNanos, + TimeUnit.NANOSECONDS, + helper.getScheduledExecutorService()); + } + private void updateHealthCheckedState(SubchannelData subchannelData) { if (subchannelData.state != READY) { return; @@ -313,11 +428,11 @@ private void updateHealthCheckedState(SubchannelData subchannelData) { updateBalancingState(READY, new FixedResultPicker(PickResult.withSubchannel(subchannelData.subchannel))); } else if (subchannelData.getHealthState() == TRANSIENT_FAILURE) { - updateBalancingState(TRANSIENT_FAILURE, new Picker(PickResult.withError( + updateBalancingState(TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError( subchannelData.healthStateInfo.getStatus()))); } else if (concludedState != TRANSIENT_FAILURE) { updateBalancingState(subchannelData.getHealthState(), - new Picker(PickResult.withNoResult())); + new FixedResultPicker(PickResult.withNoResult())); } } @@ -337,6 +452,11 @@ public void shutdown() { rawConnectivityState = SHUTDOWN; concludedState = SHUTDOWN; cancelScheduleTask(); + if (reconnectTask != 
null) { + reconnectTask.cancel(); + reconnectTask = null; + } + reconnectPolicy = null; for (SubchannelData subchannelData : subchannels.values()) { subchannelData.getSubchannel().shutdown(); @@ -350,6 +470,12 @@ public void shutdown() { * that all other subchannels must be shutdown. */ private void shutdownRemaining(SubchannelData activeSubchannelData) { + if (reconnectTask != null) { + reconnectTask.cancel(); + reconnectTask = null; + } + reconnectPolicy = null; + cancelScheduleTask(); for (SubchannelData subchannelData : subchannels.values()) { if (!subchannelData.getSubchannel().equals(activeSubchannelData.subchannel)) { @@ -370,7 +496,7 @@ private void shutdownRemaining(SubchannelData activeSubchannelData) { */ @Override public void requestConnection() { - if (!addressIndex.isValid() || rawConnectivityState == SHUTDOWN) { + if (!addressIndex.isValid() || rawConnectivityState == SHUTDOWN || reconnectTask != null) { return; } @@ -391,8 +517,17 @@ public void requestConnection() { scheduleNextConnection(); break; case TRANSIENT_FAILURE: - addressIndex.increment(); - requestConnection(); + if (!serializingRetries) { + addressIndex.increment(); + requestConnection(); + } else { + if (!addressIndex.isValid()) { + scheduleBackoff(); + } else { + subchannelData.subchannel.requestConnection(); + subchannelData.updateState(CONNECTING); + } + } break; default: // Wait for current subchannel to change state @@ -438,9 +573,10 @@ private SubchannelData createNewSubchannel(SocketAddress addr, Attributes attrs) HealthListener hcListener = new HealthListener(); final Subchannel subchannel = helper.createSubchannel( CreateSubchannelArgs.newBuilder() - .setAddresses(Lists.newArrayList( - new EquivalentAddressGroup(addr, attrs))) - .addOption(HEALTH_CONSUMER_LISTENER_ARG_KEY, hcListener) + .setAddresses(Lists.newArrayList( + new EquivalentAddressGroup(addr, attrs))) + .addOption(HEALTH_CONSUMER_LISTENER_ARG_KEY, hcListener) + 
.addOption(LoadBalancer.DISABLE_SUBCHANNEL_RECONNECT_KEY, serializingRetries) .build()); if (subchannel == null) { log.warning("Was not able to create subchannel for " + addr); @@ -458,7 +594,7 @@ private SubchannelData createNewSubchannel(SocketAddress addr, Attributes attrs) } private boolean isPassComplete() { - if (addressIndex.isValid() || subchannels.size() < addressIndex.size()) { + if (subchannels.size() < addressIndex.size()) { return false; } for (SubchannelData sc : subchannels.values()) { @@ -500,28 +636,6 @@ ConnectivityState getConcludedConnectivityState() { return this.concludedState; } - /** - * No-op picker which doesn't add any custom picking logic. It just passes already known result - * received in constructor. - */ - private static final class Picker extends SubchannelPicker { - private final PickResult result; - - Picker(PickResult result) { - this.result = checkNotNull(result, "result"); - } - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return result; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(Picker.class).add("result", result).toString(); - } - } - /** * Picker that requests connection during the first pick, and returns noResult. */ @@ -544,27 +658,26 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { } /** - * Index as in 'i', the pointer to an entry. Not a "search index." + * This contains both an ordered list of addresses and a pointer(i.e. index) to the current entry. * All updates should be done in a synchronization context. 
*/ @VisibleForTesting static final class Index { - private List addressGroups; - private int size; - private int groupIndex; - private int addressIndex; + private List orderedAddresses; + private int activeElement = 0; + private boolean enableHappyEyeballs; - public Index(List groups) { + Index(List groups, boolean enableHappyEyeballs) { + this.enableHappyEyeballs = enableHappyEyeballs; updateGroups(groups); } public boolean isValid() { - // Is invalid if empty or has incremented off the end - return groupIndex < addressGroups.size(); + return activeElement < orderedAddresses.size(); } public boolean isAtBeginning() { - return groupIndex == 0 && addressIndex == 0; + return activeElement == 0; } /** @@ -576,74 +689,155 @@ public boolean increment() { return false; } - EquivalentAddressGroup group = addressGroups.get(groupIndex); - addressIndex++; - if (addressIndex >= group.getAddresses().size()) { - groupIndex++; - addressIndex = 0; - return groupIndex < addressGroups.size(); - } + activeElement++; - return true; + return isValid(); } public void reset() { - groupIndex = 0; - addressIndex = 0; + activeElement = 0; } public SocketAddress getCurrentAddress() { if (!isValid()) { throw new IllegalStateException("Index is past the end of the address group list"); } - return addressGroups.get(groupIndex).getAddresses().get(addressIndex); + return orderedAddresses.get(activeElement).address; } public Attributes getCurrentEagAttributes() { if (!isValid()) { throw new IllegalStateException("Index is off the end of the address group list"); } - return addressGroups.get(groupIndex).getAttributes(); + return orderedAddresses.get(activeElement).attributes; } public List getCurrentEagAsList() { - return Collections.singletonList( - new EquivalentAddressGroup(getCurrentAddress(), getCurrentEagAttributes())); + return Collections.singletonList(getCurrentEag()); + } + + private EquivalentAddressGroup getCurrentEag() { + if (!isValid()) { + throw new IllegalStateException("Index is 
past the end of the address group list"); + } + return orderedAddresses.get(activeElement).asEag(); } /** * Update to new groups, resetting the current index. */ public void updateGroups(List newGroups) { - addressGroups = checkNotNull(newGroups, "newGroups"); + checkNotNull(newGroups, "newGroups"); + orderedAddresses = enableHappyEyeballs + ? updateGroupsHE(newGroups) + : updateGroupsNonHE(newGroups); reset(); - int size = 0; - for (EquivalentAddressGroup eag : newGroups) { - size += eag.getAddresses().size(); - } - this.size = size; } /** * Returns false if the needle was not found and the current index was left unchanged. */ public boolean seekTo(SocketAddress needle) { - for (int i = 0; i < addressGroups.size(); i++) { - EquivalentAddressGroup group = addressGroups.get(i); - int j = group.getAddresses().indexOf(needle); - if (j == -1) { - continue; + checkNotNull(needle, "needle"); + for (int i = 0; i < orderedAddresses.size(); i++) { + if (orderedAddresses.get(i).address.equals(needle)) { + this.activeElement = i; + return true; } - this.groupIndex = i; - this.addressIndex = j; - return true; } return false; } public int size() { - return size; + return orderedAddresses.size(); } + + private List updateGroupsNonHE(List newGroups) { + List entries = new ArrayList<>(); + for (int g = 0; g < newGroups.size(); g++) { + EquivalentAddressGroup eag = newGroups.get(g); + for (int a = 0; a < eag.getAddresses().size(); a++) { + SocketAddress addr = eag.getAddresses().get(a); + entries.add(new UnwrappedEag(eag.getAttributes(), addr)); + } + } + + return entries; + } + + private List updateGroupsHE(List newGroups) { + Boolean firstIsV6 = null; + List v4Entries = new ArrayList<>(); + List v6Entries = new ArrayList<>(); + for (int g = 0; g < newGroups.size(); g++) { + EquivalentAddressGroup eag = newGroups.get(g); + for (int a = 0; a < eag.getAddresses().size(); a++) { + SocketAddress addr = eag.getAddresses().get(a); + boolean isIpV4 = addr instanceof InetSocketAddress + 
&& ((InetSocketAddress) addr).getAddress() instanceof Inet4Address; + if (isIpV4) { + if (firstIsV6 == null) { + firstIsV6 = false; + } + v4Entries.add(new UnwrappedEag(eag.getAttributes(), addr)); + } else { + if (firstIsV6 == null) { + firstIsV6 = true; + } + v6Entries.add(new UnwrappedEag(eag.getAttributes(), addr)); + } + } + } + + return firstIsV6 != null && firstIsV6 + ? interleave(v6Entries, v4Entries) + : interleave(v4Entries, v6Entries); + } + + private List interleave(List firstFamily, + List secondFamily) { + if (firstFamily.isEmpty()) { + return secondFamily; + } + if (secondFamily.isEmpty()) { + return firstFamily; + } + + List result = new ArrayList<>(firstFamily.size() + secondFamily.size()); + for (int i = 0; i < Math.max(firstFamily.size(), secondFamily.size()); i++) { + if (i < firstFamily.size()) { + result.add(firstFamily.get(i)); + } + if (i < secondFamily.size()) { + result.add(secondFamily.get(i)); + } + } + return result; + } + + private static final class UnwrappedEag { + private final Attributes attributes; + private final SocketAddress address; + + public UnwrappedEag(Attributes attributes, SocketAddress address) { + this.attributes = attributes; + this.address = address; + } + + private EquivalentAddressGroup asEag() { + return new EquivalentAddressGroup(address, attributes); + } + } + } + + @VisibleForTesting + int getIndexLocation() { + return addressIndex.activeElement; + } + + @VisibleForTesting + boolean isIndexValid() { + return addressIndex.isValid(); } private static final class SubchannelData { @@ -702,4 +896,5 @@ public PickFirstLeafLoadBalancerConfig(@Nullable Boolean shuffleAddressList) { this.randomSeed = randomSeed; } } + } diff --git a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java index acef79d3d9f..cf4b4c94e04 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java +++ 
b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java @@ -22,14 +22,11 @@ import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import com.google.common.base.MoreObjects; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Random; import java.util.concurrent.atomic.AtomicBoolean; @@ -66,9 +63,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { PickFirstLoadBalancerConfig config = (PickFirstLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); if (config.shuffleAddressList != null && config.shuffleAddressList) { - servers = new ArrayList(servers); - Collections.shuffle(servers, - config.randomSeed != null ? new Random(config.randomSeed) : new Random()); + servers = PickFirstLeafLoadBalancer.shuffle( + servers, config.randomSeed != null ? new Random(config.randomSeed) : new Random()); } } @@ -87,7 +83,7 @@ public void onSubchannelState(ConnectivityStateInfo stateInfo) { // The channel state does not get updated when doing name resolving today, so for the moment // let LB report CONNECTION and call subchannel.requestConnection() immediately. - updateBalancingState(CONNECTING, new Picker(PickResult.withSubchannel(subchannel))); + updateBalancingState(CONNECTING, new FixedResultPicker(PickResult.withNoResult())); subchannel.requestConnection(); } else { subchannel.updateAddresses(servers); @@ -105,7 +101,7 @@ public void handleNameResolutionError(Status error) { // NB(lukaszx0) Whether we should propagate the error unconditionally is arguable. It's fine // for time being. 
- updateBalancingState(TRANSIENT_FAILURE, new Picker(PickResult.withError(error))); + updateBalancingState(TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); } private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { @@ -134,18 +130,18 @@ private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo SubchannelPicker picker; switch (newState) { case IDLE: - picker = new RequestConnectionPicker(subchannel); + picker = new RequestConnectionPicker(); break; case CONNECTING: // It's safe to use RequestConnectionPicker here, so when coming from IDLE we could leave // the current picker in-place. But ignoring the potential optimization is simpler. - picker = new Picker(PickResult.withNoResult()); + picker = new FixedResultPicker(PickResult.withNoResult()); break; case READY: - picker = new Picker(PickResult.withSubchannel(subchannel)); + picker = new FixedResultPicker(PickResult.withSubchannel(subchannel)); break; case TRANSIENT_FAILURE: - picker = new Picker(PickResult.withError(stateInfo.getStatus())); + picker = new FixedResultPicker(PickResult.withError(stateInfo.getStatus())); break; default: throw new IllegalArgumentException("Unsupported state:" + newState); @@ -173,46 +169,14 @@ public void requestConnection() { } } - /** - * No-op picker which doesn't add any custom picking logic. It just passes already known result - * received in constructor. - */ - private static final class Picker extends SubchannelPicker { - private final PickResult result; - - Picker(PickResult result) { - this.result = checkNotNull(result, "result"); - } - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return result; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(Picker.class).add("result", result).toString(); - } - } - /** Picker that requests connection during the first pick, and returns noResult. 
*/ private final class RequestConnectionPicker extends SubchannelPicker { - private final Subchannel subchannel; private final AtomicBoolean connectionRequested = new AtomicBoolean(false); - RequestConnectionPicker(Subchannel subchannel) { - this.subchannel = checkNotNull(subchannel, "subchannel"); - } - @Override public PickResult pickSubchannel(PickSubchannelArgs args) { if (connectionRequested.compareAndSet(false, true)) { - helper.getSynchronizationContext().execute(new Runnable() { - @Override - public void run() { - subchannel.requestConnection(); - } - }); + helper.getSynchronizationContext().execute(PickFirstLoadBalancer.this::requestConnection); } return PickResult.withNoResult(); } diff --git a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancerProvider.java b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancerProvider.java index 92178ccae24..83b3fb7d8e6 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancerProvider.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancerProvider.java @@ -36,7 +36,7 @@ public final class PickFirstLoadBalancerProvider extends LoadBalancerProvider { public static final String GRPC_PF_USE_HAPPY_EYEBALLS = "GRPC_PF_USE_HAPPY_EYEBALLS"; private static final String SHUFFLE_ADDRESS_LIST_KEY = "shuffleAddressList"; - private static boolean enableNewPickFirst = + static boolean enableNewPickFirst = GrpcUtil.getFlag("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", false); public static boolean isEnabledHappyEyeballs() { diff --git a/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java b/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java index b3f646d6099..58c7803346f 100644 --- a/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java +++ b/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java @@ -147,18 +147,9 @@ public ProxySelector get() { } }; - /** - * Experimental environment variable name for enabling proxy support. 
- * - * @deprecated Use the standard Java proxy configuration instead with flags such as: - * -Dhttps.proxyHost=HOST -Dhttps.proxyPort=PORT - */ - @Deprecated - private static final String GRPC_PROXY_ENV_VAR = "GRPC_PROXY_EXP"; // Do not hard code a ProxySelector because the global default ProxySelector can change private final Supplier proxySelector; private final AuthenticationProvider authenticationProvider; - private final InetSocketAddress overrideProxyAddress; // We want an HTTPS proxy, which operates on the entire data stream (See IETF rfc2817). static final String PROXY_SCHEME = "https"; @@ -168,21 +159,15 @@ public ProxySelector get() { * {@link ProxyDetectorImpl.AuthenticationProvider} to detect proxy parameters. */ public ProxyDetectorImpl() { - this(DEFAULT_PROXY_SELECTOR, DEFAULT_AUTHENTICATOR, System.getenv(GRPC_PROXY_ENV_VAR)); + this(DEFAULT_PROXY_SELECTOR, DEFAULT_AUTHENTICATOR); } @VisibleForTesting ProxyDetectorImpl( Supplier proxySelector, - AuthenticationProvider authenticationProvider, - @Nullable String proxyEnvString) { + AuthenticationProvider authenticationProvider) { this.proxySelector = checkNotNull(proxySelector); this.authenticationProvider = checkNotNull(authenticationProvider); - if (proxyEnvString != null) { - overrideProxyAddress = overrideProxy(proxyEnvString); - } else { - overrideProxyAddress = null; - } } @Nullable @@ -191,12 +176,6 @@ public ProxiedSocketAddress proxyFor(SocketAddress targetServerAddress) throws I if (!(targetServerAddress instanceof InetSocketAddress)) { return null; } - if (overrideProxyAddress != null) { - return HttpConnectProxiedSocketAddress.newBuilder() - .setProxyAddress(overrideProxyAddress) - .setTargetAddress((InetSocketAddress) targetServerAddress) - .build(); - } return detectProxy((InetSocketAddress) targetServerAddress); } @@ -272,27 +251,6 @@ private ProxiedSocketAddress detectProxy(InetSocketAddress targetAddr) throws IO .build(); } - /** - * GRPC_PROXY_EXP is deprecated but let's maintain 
compatibility for now. - */ - private static InetSocketAddress overrideProxy(String proxyHostPort) { - if (proxyHostPort == null) { - return null; - } - - String[] parts = proxyHostPort.split(":", 2); - int port = 80; - if (parts.length > 1) { - port = Integer.parseInt(parts[1]); - } - log.warning( - "Detected GRPC_PROXY_EXP and will honor it, but this feature will " - + "be removed in a future release. Use the JVM flags " - + "\"-Dhttps.proxyHost=HOST -Dhttps.proxyPort=PORT\" to set the https proxy for " - + "this JVM."); - return new InetSocketAddress(parts[0], port); - } - /** * This interface makes unit testing easier by avoiding direct calls to static methods. */ diff --git a/core/src/main/java/io/grpc/internal/ReadableBuffer.java b/core/src/main/java/io/grpc/internal/ReadableBuffer.java index 6963c78203e..20f64719875 100644 --- a/core/src/main/java/io/grpc/internal/ReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/ReadableBuffer.java @@ -71,15 +71,6 @@ public interface ReadableBuffer extends Closeable { */ void readBytes(byte[] dest, int destOffset, int length); - /** - * Reads from this buffer until the destination's position reaches its limit, and increases the - * read position by the number of the transferred bytes. - * - * @param dest the destination buffer to receive the bytes. - * @throws IndexOutOfBoundsException if required bytes are not readable - */ - void readBytes(ByteBuffer dest); - /** * Reads {@code length} bytes from this buffer and writes them to the destination stream. * Increments the read position by {@code length}. 
If the required bytes are not readable, throws diff --git a/core/src/main/java/io/grpc/internal/ReadableBuffers.java b/core/src/main/java/io/grpc/internal/ReadableBuffers.java index e512c810f84..439745e29b2 100644 --- a/core/src/main/java/io/grpc/internal/ReadableBuffers.java +++ b/core/src/main/java/io/grpc/internal/ReadableBuffers.java @@ -171,15 +171,6 @@ public void readBytes(byte[] dest, int destIndex, int length) { offset += length; } - @Override - public void readBytes(ByteBuffer dest) { - Preconditions.checkNotNull(dest, "dest"); - int length = dest.remaining(); - checkReadable(length); - dest.put(bytes, offset, length); - offset += length; - } - @Override public void readBytes(OutputStream dest, int length) throws IOException { checkReadable(length); @@ -262,21 +253,6 @@ public void readBytes(byte[] dest, int destOffset, int length) { bytes.get(dest, destOffset, length); } - @Override - public void readBytes(ByteBuffer dest) { - Preconditions.checkNotNull(dest, "dest"); - int length = dest.remaining(); - checkReadable(length); - - // Change the limit so that only length bytes are available. - int prevLimit = bytes.limit(); - ((Buffer) bytes).limit(bytes.position() + length); - - // Write the bytes and restore the original limit. 
- dest.put(bytes); - bytes.limit(prevLimit); - } - @Override public void readBytes(OutputStream dest, int length) throws IOException { checkReadable(length); diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index ba9424ea25c..0c37a0beaca 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -22,6 +22,8 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; +import com.google.errorprone.annotations.CheckReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.ClientStreamTracer; import io.grpc.Compressor; @@ -47,9 +49,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.CheckForNull; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** A logical {@link ClientStream} that is retriable. */ abstract class RetriableStream implements ClientStream { @@ -166,7 +166,8 @@ private Runnable commit(final Substream winningSubstream) { final boolean wasCancelled = (scheduledRetry != null) ? 
scheduledRetry.isCancelled() : false; final Future retryFuture; - if (scheduledRetry != null) { + final boolean retryWasScheduled = scheduledRetry != null; + if (retryWasScheduled) { retryFuture = scheduledRetry.markCancelled(); scheduledRetry = null; } else { @@ -190,8 +191,10 @@ public void run() { substream.stream.cancel(CANCELLED_BECAUSE_COMMITTED); } } - if (retryFuture != null) { - retryFuture.cancel(false); + if (retryWasScheduled) { + if (retryFuture != null) { + retryFuture.cancel(false); + } if (!wasCancelled && inFlightSubStreams.decrementAndGet() == Integer.MIN_VALUE) { assert savedCloseMasterListenerReason != null; listenerSerializeExecutor.execute( @@ -245,7 +248,8 @@ private void commitAndRun(Substream winningSubstream) { // returns null means we should not create new sub streams, e.g. cancelled or // other close condition is met for retriableStream. @Nullable - private Substream createSubstream(int previousAttemptCount, boolean isTransparentRetry) { + private Substream createSubstream(int previousAttemptCount, boolean isTransparentRetry, + boolean isHedgedStream) { int inFlight; do { inFlight = inFlightSubStreams.get(); @@ -266,7 +270,8 @@ public ClientStreamTracer newClientStreamTracer( Metadata newHeaders = updateHeaders(headers, previousAttemptCount); // NOTICE: This set _must_ be done before stream.start() and it actually is. - sub.stream = newSubstream(newHeaders, tracerFactory, previousAttemptCount, isTransparentRetry); + sub.stream = newSubstream(newHeaders, tracerFactory, previousAttemptCount, isTransparentRetry, + isHedgedStream); return sub; } @@ -276,7 +281,7 @@ public ClientStreamTracer newClientStreamTracer( */ abstract ClientStream newSubstream( Metadata headers, ClientStreamTracer.Factory tracerFactory, int previousAttempts, - boolean isTransparentRetry); + boolean isTransparentRetry, boolean isHedgedStream); /** Adds grpc-previous-rpc-attempts in the headers of a retry/hedging RPC. 
*/ @VisibleForTesting @@ -382,7 +387,7 @@ public void runWith(Substream substream) { } } - /** Starts the first PRC attempt. */ + /** Starts the first RPC attempt. */ @Override public final void start(ClientStreamListener listener) { masterListener = listener; @@ -398,7 +403,7 @@ public final void start(ClientStreamListener listener) { state.buffer.add(new StartEntry()); } - Substream substream = createSubstream(0, false); + Substream substream = createSubstream(0, false, false); if (substream == null) { return; } @@ -471,7 +476,7 @@ public void run() { // If this run is not cancelled, the value of state.hedgingAttemptCount won't change // until state.addActiveHedge() is called subsequently, even the state could possibly // change. - Substream newSubstream = createSubstream(state.hedgingAttemptCount, false); + Substream newSubstream = createSubstream(state.hedgingAttemptCount, false, true); if (newSubstream == null) { return; } @@ -846,6 +851,15 @@ public void run() { } } + private static final boolean isExperimentalRetryJitterEnabled = GrpcUtil + .getFlag("GRPC_EXPERIMENTAL_XDS_RLS_LB", true); + + public static long intervalWithJitter(long intervalNanos) { + double inverseJitterFactor = isExperimentalRetryJitterEnabled + ? 0.4 * random.nextDouble() + 0.8 : random.nextDouble(); + return (long) (intervalNanos * inverseJitterFactor); + } + private static final class SavedCloseMasterListenerReason { private final Status status; private final RpcProgress progress; @@ -927,9 +941,8 @@ public void run() { && localOnlyTransparentRetries.incrementAndGet() > 1_000) { commitAndRun(substream); if (state.winningSubstream == substream) { - Status tooManyTransparentRetries = Status.INTERNAL - .withDescription("Too many transparent retries. Might be a bug in gRPC") - .withCause(status.asRuntimeException()); + Status tooManyTransparentRetries = GrpcUtil.statusWithDetails( + Status.Code.INTERNAL, "Too many transparent retries. 
Might be a bug in gRPC", status); safeCloseMasterListener(tooManyTransparentRetries, rpcProgress, trailers); } return; @@ -940,7 +953,8 @@ public void run() { || (rpcProgress == RpcProgress.REFUSED && noMoreTransparentRetry.compareAndSet(false, true))) { // transparent retry - final Substream newSubstream = createSubstream(substream.previousAttemptCount, true); + final Substream newSubstream = createSubstream(substream.previousAttemptCount, + true, false); if (newSubstream == null) { return; } @@ -992,7 +1006,8 @@ public void run() { RetryPlan retryPlan = makeRetryDecision(status, trailers); if (retryPlan.shouldRetry) { // retry - Substream newSubstream = createSubstream(substream.previousAttemptCount + 1, false); + Substream newSubstream = createSubstream(substream.previousAttemptCount + 1, + false, false); if (newSubstream == null) { return; } @@ -1066,7 +1081,7 @@ private RetryPlan makeRetryDecision(Status status, Metadata trailer) { if (pushbackMillis == null) { if (isRetryableStatusCode) { shouldRetry = true; - backoffNanos = (long) (nextBackoffIntervalNanos * random.nextDouble()); + backoffNanos = intervalWithJitter(nextBackoffIntervalNanos); nextBackoffIntervalNanos = Math.min( (long) (nextBackoffIntervalNanos * retryPolicy.backoffMultiplier), retryPolicy.maxBackoffNanos); diff --git a/core/src/main/java/io/grpc/internal/RetryingNameResolver.java b/core/src/main/java/io/grpc/internal/RetryingNameResolver.java index 6dcfcd3534a..90827fa8acb 100644 --- a/core/src/main/java/io/grpc/internal/RetryingNameResolver.java +++ b/core/src/main/java/io/grpc/internal/RetryingNameResolver.java @@ -17,7 +17,6 @@ package io.grpc.internal; import com.google.common.annotations.VisibleForTesting; -import io.grpc.Attributes; import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.SynchronizationContext; @@ -28,16 +27,22 @@ * *

The {@link NameResolver} used with this */ -final class RetryingNameResolver extends ForwardingNameResolver { +public final class RetryingNameResolver extends ForwardingNameResolver { + public static NameResolver wrap(NameResolver retriedNameResolver, Args args) { + // For migration, this might become conditional + return new RetryingNameResolver( + retriedNameResolver, + new BackoffPolicyRetryScheduler( + new ExponentialBackoffPolicy.Provider(), + args.getScheduledExecutorService(), + args.getSynchronizationContext()), + args.getSynchronizationContext()); + } private final NameResolver retriedNameResolver; private final RetryScheduler retryScheduler; private final SynchronizationContext syncContext; - static final Attributes.Key RESOLUTION_RESULT_LISTENER_KEY - = Attributes.Key.create( - "io.grpc.internal.RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY"); - /** * Creates a new {@link RetryingNameResolver}. * @@ -88,18 +93,7 @@ private class RetryingListener extends Listener2 { @Override public void onResult(ResolutionResult resolutionResult) { - // If the resolution result listener is already an attribute it indicates that a name resolver - // has already been wrapped with this class. This indicates a misconfiguration. - if (resolutionResult.getAttributes().get(RESOLUTION_RESULT_LISTENER_KEY) != null) { - throw new IllegalStateException( - "RetryingNameResolver can only be used once to wrap a NameResolver"); - } - - // To have retry behavior for name resolvers that haven't migrated to onResult2. 
- delegateListener.onResult(resolutionResult.toBuilder().setAttributes( - resolutionResult.getAttributes().toBuilder() - .set(RESOLUTION_RESULT_LISTENER_KEY, new ResolutionResultListener()).build()) - .build()); + syncContext.execute(() -> onResult2(resolutionResult)); } @Override @@ -119,19 +113,4 @@ public void onError(Status error) { syncContext.execute(() -> retryScheduler.schedule(new DelayedNameResolverRefresh())); } } - - /** - * Simple callback class to store in {@link ResolutionResult} attributes so that - * ManagedChannel can indicate if the resolved addresses were accepted. Temporary until - * the Listener2.onResult() API can be changed to return a boolean for this purpose. - */ - class ResolutionResultListener { - public void resolutionAttempted(Status successStatus) { - if (successStatus.isOk()) { - retryScheduler.reset(); - } else { - retryScheduler.schedule(new DelayedNameResolverRefresh()); - } - } - } } diff --git a/core/src/main/java/io/grpc/internal/ScParser.java b/core/src/main/java/io/grpc/internal/ScParser.java index f94449f7c7b..71d6d33877f 100644 --- a/core/src/main/java/io/grpc/internal/ScParser.java +++ b/core/src/main/java/io/grpc/internal/ScParser.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; @@ -31,18 +32,18 @@ public final class ScParser extends NameResolver.ServiceConfigParser { private final boolean retryEnabled; private final int maxRetryAttemptsLimit; private final int maxHedgedAttemptsLimit; - private final AutoConfiguredLoadBalancerFactory autoLoadBalancerFactory; + private final LoadBalancerProvider parser; /** Creates a parse with global retry settings and an auto configured lb factory. 
*/ public ScParser( boolean retryEnabled, int maxRetryAttemptsLimit, int maxHedgedAttemptsLimit, - AutoConfiguredLoadBalancerFactory autoLoadBalancerFactory) { + LoadBalancerProvider parser) { this.retryEnabled = retryEnabled; this.maxRetryAttemptsLimit = maxRetryAttemptsLimit; this.maxHedgedAttemptsLimit = maxHedgedAttemptsLimit; - this.autoLoadBalancerFactory = checkNotNull(autoLoadBalancerFactory, "autoLoadBalancerFactory"); + this.parser = checkNotNull(parser, "parser"); } @Override @@ -50,7 +51,9 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { try { Object loadBalancingPolicySelection; ConfigOrError choiceFromLoadBalancer = - autoLoadBalancerFactory.parseLoadBalancerPolicy(rawServiceConfig); + parser.parseLoadBalancingPolicyConfig(rawServiceConfig); + // TODO(ejona): The Provider API doesn't allow null, but AutoConfiguredLoadBalancerFactory can + // return null and it will need tweaking to ManagedChannelImpl.defaultServiceConfig to fix. if (choiceFromLoadBalancer == null) { loadBalancingPolicySelection = null; } else if (choiceFromLoadBalancer.getError() != null) { @@ -66,8 +69,19 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { maxHedgedAttemptsLimit, loadBalancingPolicySelection)); } catch (RuntimeException e) { + // TODO(ejona): We really don't want parsers throwing exceptions; they should return an error. + // However, right now ManagedChannelServiceConfig itself uses exceptions like + // ClassCastException. We should handle those with a graceful return within + // ManagedChannelServiceConfig and then get rid of this case. Then all exceptions are + // "unexpected" and the INTERNAL status code makes it clear a bug needs to be fixed. 
return ConfigOrError.fromError( Status.UNKNOWN.withDescription("failed to parse service config").withCause(e)); + } catch (Throwable t) { + // Even catch Errors, since broken config parsing could trigger AssertionError, + // StackOverflowError, and other errors we can reasonably safely recover. Since the config + // could be untrusted, we want to error on the side of recovering. + return ConfigOrError.fromError( + Status.INTERNAL.withDescription("Unexpected error parsing service config").withCause(t)); } } } diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java index dda36258e7c..e224384ce8f 100644 --- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java @@ -373,10 +373,10 @@ private void closedInternal(Status status) { } else { call.cancelled = true; listener.onCancel(); - // The status will not have a cause in all failure scenarios but we want to make sure + // The status will not have a cause in all failure scenarios, but we want to make sure // we always cancel the context with one to keep the context cancelled state consistent. 
- cancelCause = InternalStatus.asRuntimeException( - Status.CANCELLED.withDescription("RPC cancelled"), null, false); + cancelCause = InternalStatus.asRuntimeExceptionWithoutStacktrace( + Status.CANCELLED.withDescription("RPC cancelled"), null); } } finally { // Cancel context after delivering RPC closure notification to allow the application to diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index cec2a13a301..dc0709e1fb8 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -31,6 +31,7 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.BinaryLog; import io.grpc.CompressorRegistry; @@ -75,7 +76,6 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.concurrent.GuardedBy; /** * Default implementation of {@link io.grpc.Server}, for creation by transports. @@ -887,8 +887,8 @@ private void closedInternal(final Status status) { // failed status has an exception we will create one here if needed. Throwable cancelCause = status.getCause(); if (cancelCause == null) { - cancelCause = InternalStatus.asRuntimeException( - Status.CANCELLED.withDescription("RPC cancelled"), null, false); + cancelCause = InternalStatus.asRuntimeExceptionWithoutStacktrace( + Status.CANCELLED.withDescription("RPC cancelled"), null); } // The callExecutor might be busy doing user work. 
To avoid waiting, use an executor that diff --git a/core/src/main/java/io/grpc/internal/ServerImplBuilder.java b/core/src/main/java/io/grpc/internal/ServerImplBuilder.java index b679baf3a8b..62a0e66f314 100644 --- a/core/src/main/java/io/grpc/internal/ServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ServerImplBuilder.java @@ -31,6 +31,9 @@ import io.grpc.HandlerRegistry; import io.grpc.InternalChannelz; import io.grpc.InternalConfiguratorRegistry; +import io.grpc.MetricInstrumentRegistry; +import io.grpc.MetricRecorder; +import io.grpc.MetricSink; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.ServerCallExecutorSupplier; @@ -80,6 +83,7 @@ public static ServerBuilder forPort(int port) { final List transportFilters = new ArrayList<>(); final List interceptors = new ArrayList<>(); private final List streamTracerFactories = new ArrayList<>(); + final List metricSinks = new ArrayList<>(); private final ClientTransportServersBuilder clientTransportServersBuilder; HandlerRegistry fallbackRegistry = DEFAULT_FALLBACK_REGISTRY; ObjectPool executorPool = DEFAULT_EXECUTOR_POOL; @@ -99,12 +103,13 @@ public static ServerBuilder forPort(int port) { ServerCallExecutorSupplier executorSupplier; /** - * An interface to provide to provide transport specific information for the server. This method + * An interface to provide transport specific information for the server. This method * is meant for Transport implementors and should not be used by normal users. */ public interface ClientTransportServersBuilder { InternalServer buildClientTransportServers( - List streamTracerFactories); + List streamTracerFactories, + MetricRecorder metricRecorder); } /** @@ -157,6 +162,15 @@ public ServerImplBuilder intercept(ServerInterceptor interceptor) { return this; } + /** + * Adds a MetricSink to the server. 
+ */ + @Override + public ServerImplBuilder addMetricSink(MetricSink metricSink) { + metricSinks.add(checkNotNull(metricSink, "metricSink")); + return this; + } + @Override public ServerImplBuilder addStreamTracerFactory(ServerStreamTracer.Factory factory) { streamTracerFactories.add(checkNotNull(factory, "factory")); @@ -241,8 +255,11 @@ public void setDeadlineTicker(Deadline.Ticker ticker) { @Override public Server build() { + MetricRecorder metricRecorder = new MetricRecorderImpl(metricSinks, + MetricInstrumentRegistry.getDefaultRegistry()); return new ServerImpl(this, - clientTransportServersBuilder.buildClientTransportServers(getTracerFactories()), + clientTransportServersBuilder.buildClientTransportServers( + getTracerFactories(), metricRecorder), Context.ROOT); } diff --git a/core/src/main/java/io/grpc/internal/SharedResourceHolder.java b/core/src/main/java/io/grpc/internal/SharedResourceHolder.java index 67d1a98b545..1dfa1f90718 100644 --- a/core/src/main/java/io/grpc/internal/SharedResourceHolder.java +++ b/core/src/main/java/io/grpc/internal/SharedResourceHolder.java @@ -134,18 +134,16 @@ synchronized T releaseInternal(final Resource resource, final T instance) public void run() { synchronized (SharedResourceHolder.this) { // Refcount may have gone up since the task was scheduled. Re-check it. 
- if (cached.refcount == 0) { - try { - resource.close(instance); - } finally { - instances.remove(resource); - if (instances.isEmpty()) { - destroyer.shutdown(); - destroyer = null; - } - } + if (cached.refcount != 0) { + return; + } + instances.remove(resource); + if (instances.isEmpty()) { + destroyer.shutdown(); + destroyer = null; } } + resource.close(instance); } }), DESTROY_DELAY_SECONDS, TimeUnit.SECONDS); } diff --git a/core/src/main/java/io/grpc/internal/SimpleDisconnectError.java b/core/src/main/java/io/grpc/internal/SimpleDisconnectError.java new file mode 100644 index 00000000000..addbfbe10a3 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/SimpleDisconnectError.java @@ -0,0 +1,68 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import javax.annotation.concurrent.Immutable; + +/** + * Represents a fixed, static reason for disconnection. + */ +@Immutable +public enum SimpleDisconnectError implements DisconnectError { + /** + * The subchannel was shut down for various reasons like parent channel shutdown, + * idleness, or load balancing policy changes. + */ + SUBCHANNEL_SHUTDOWN("subchannel shutdown"), + + /** + * Connection was reset (e.g., ECONNRESET, WSAECONNERESET). + */ + CONNECTION_RESET("connection reset"), + + /** + * Connection timed out (e.g., ETIMEDOUT, WSAETIMEDOUT), including closures + * from gRPC keepalives. 
+ */ + CONNECTION_TIMED_OUT("connection timed out"), + + /** + * Connection was aborted (e.g., ECONNABORTED, WSAECONNABORTED). + */ + CONNECTION_ABORTED("connection aborted"), + + /** + * Any socket error not covered by other specific disconnect errors. + */ + SOCKET_ERROR("socket error"), + + /** + * A catch-all for any other unclassified reason. + */ + UNKNOWN("unknown"); + + private final String errorTag; + + SimpleDisconnectError(String errorTag) { + this.errorTag = errorTag; + } + + @Override + public String toErrorString() { + return this.errorTag; + } +} diff --git a/core/src/main/java/io/grpc/internal/SpiffeUtil.java b/core/src/main/java/io/grpc/internal/SpiffeUtil.java new file mode 100644 index 00000000000..9eafc9950e2 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/SpiffeUtil.java @@ -0,0 +1,312 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.internal; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.base.Optional; +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.io.Files; +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.CertificateParsingException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * Provides utilities to manage SPIFFE bundles, extract SPIFFE IDs from X.509 certificate chains, + * and parse SPIFFE IDs. + * @see Standard + */ +public final class SpiffeUtil { + + private static final Integer URI_SAN_TYPE = 6; + private static final String USE_PARAMETER_VALUE = "x509-svid"; + private static final ImmutableSet KTY_PARAMETER_VALUES = ImmutableSet.of("RSA", "EC"); + private static final String CERTIFICATE_PREFIX = "-----BEGIN CERTIFICATE-----\n"; + private static final String CERTIFICATE_SUFFIX = "-----END CERTIFICATE-----"; + private static final String PREFIX = "spiffe://"; + + private SpiffeUtil() {} + + /** + * Parses a URI string, applies validation rules described in SPIFFE standard, and, in case of + * success, returns parsed TrustDomain and Path. 
+ * + * @param uri a String representing a SPIFFE ID + */ + public static SpiffeId parse(String uri) { + doInitialUriValidation(uri); + checkArgument(uri.toLowerCase(Locale.US).startsWith(PREFIX), "Spiffe Id must start with " + + PREFIX); + String domainAndPath = uri.substring(PREFIX.length()); + String trustDomain; + String path; + if (!domainAndPath.contains("/")) { + trustDomain = domainAndPath; + path = ""; + } else { + String[] parts = domainAndPath.split("/", 2); + trustDomain = parts[0]; + path = parts[1]; + checkArgument(!path.isEmpty(), "Path must not include a trailing '/'"); + } + validateTrustDomain(trustDomain); + validatePath(path); + if (!path.isEmpty()) { + path = "/" + path; + } + return new SpiffeId(trustDomain, path); + } + + private static void doInitialUriValidation(String uri) { + checkArgument(checkNotNull(uri, "uri").length() > 0, "Spiffe Id can't be empty"); + checkArgument(uri.length() <= 2048, "Spiffe Id maximum length is 2048 characters"); + checkArgument(!uri.contains("#"), "Spiffe Id must not contain query fragments"); + checkArgument(!uri.contains("?"), "Spiffe Id must not contain query parameters"); + } + + private static void validateTrustDomain(String trustDomain) { + checkArgument(!trustDomain.isEmpty(), "Trust Domain can't be empty"); + checkArgument(trustDomain.length() < 256, "Trust Domain maximum length is 255 characters"); + checkArgument(trustDomain.matches("[a-z0-9._-]+"), + "Trust Domain must contain only letters, numbers, dots, dashes, and underscores" + + " ([a-z0-9.-_])"); + } + + private static void validatePath(String path) { + if (path.isEmpty()) { + return; + } + checkArgument(!path.endsWith("/"), "Path must not include a trailing '/'"); + for (String segment : Splitter.on("/").split(path)) { + validatePathSegment(segment); + } + } + + private static void validatePathSegment(String pathSegment) { + checkArgument(!pathSegment.isEmpty(), "Individual path segments must not be empty"); + 
checkArgument(!(pathSegment.equals(".") || pathSegment.equals("..")), + "Individual path segments must not be relative path modifiers (i.e. ., ..)"); + checkArgument(pathSegment.matches("[a-zA-Z0-9._-]+"), + "Individual path segments must contain only letters, numbers, dots, dashes, and underscores" + + " ([a-zA-Z0-9.-_])"); + } + + /** + * Returns the SPIFFE ID from the leaf certificate, if present. + * + * @param certChain certificate chain to extract SPIFFE ID from + */ + public static Optional extractSpiffeId(X509Certificate[] certChain) + throws CertificateParsingException { + checkArgument(checkNotNull(certChain, "certChain").length > 0, "certChain can't be empty"); + Collection> subjectAltNames = certChain[0].getSubjectAlternativeNames(); + if (subjectAltNames == null) { + return Optional.absent(); + } + String uri = null; + // Search for the unique URI SAN. + for (List altName : subjectAltNames) { + if (altName.size() < 2 ) { + continue; + } + if (URI_SAN_TYPE.equals(altName.get(0))) { + if (uri != null) { + throw new IllegalArgumentException("Multiple URI SAN values found in the leaf cert."); + } + uri = (String) altName.get(1); + } + } + if (uri == null) { + return Optional.absent(); + } + return Optional.of(parse(uri)); + } + + /** + * Loads a SPIFFE trust bundle from a file, parsing it from the JSON format. + * In case of success, returns {@link SpiffeBundle}. + * If any element of the JSON content is invalid or unsupported, an + * {@link IllegalArgumentException} is thrown and the entire Bundle is considered invalid. 
+ * + * @param trustBundleFile the file path to the JSON file containing the trust bundle + * @see JSON format + * @see JWK entry format + * @see x5c (certificate) parameter + */ + public static SpiffeBundle loadTrustBundleFromFile(String trustBundleFile) throws IOException { + Map trustDomainsNode = readTrustDomainsFromFile(trustBundleFile); + Map> trustBundleMap = new HashMap<>(); + Map sequenceNumbers = new HashMap<>(); + for (String trustDomainName : trustDomainsNode.keySet()) { + Map domainNode = JsonUtil.getObject(trustDomainsNode, trustDomainName); + if (domainNode.size() == 0) { + trustBundleMap.put(trustDomainName, Collections.emptyList()); + continue; + } + Long sequenceNumber = JsonUtil.getNumberAsLong(domainNode, "spiffe_sequence"); + sequenceNumbers.put(trustDomainName, sequenceNumber == null ? -1L : sequenceNumber); + List> keysNode = JsonUtil.getListOfObjects(domainNode, "keys"); + if (keysNode == null || keysNode.size() == 0) { + trustBundleMap.put(trustDomainName, Collections.emptyList()); + continue; + } + trustBundleMap.put(trustDomainName, extractCert(keysNode, trustDomainName)); + } + return new SpiffeBundle(sequenceNumbers, trustBundleMap); + } + + private static Map readTrustDomainsFromFile(String filePath) throws IOException { + File file = new File(checkNotNull(filePath, "trustBundleFile")); + String json = new String(Files.toByteArray(file), StandardCharsets.UTF_8); + Object jsonObject = JsonParser.parse(json); + if (!(jsonObject instanceof Map)) { + throw new IllegalArgumentException( + "SPIFFE Trust Bundle should be a JSON object. Found: " + + (jsonObject == null ? 
null : jsonObject.getClass())); + } + @SuppressWarnings("unchecked") + Map root = (Map)jsonObject; + Map trustDomainsNode = JsonUtil.getObject(root, "trust_domains"); + checkNotNull(trustDomainsNode, "Mandatory trust_domains element is missing"); + checkArgument(trustDomainsNode.size() > 0, "Mandatory trust_domains element is missing"); + return trustDomainsNode; + } + + private static void checkJwkEntry(Map jwkNode, String trustDomainName) { + String kty = JsonUtil.getString(jwkNode, "kty"); + if (kty == null || !KTY_PARAMETER_VALUES.contains(kty)) { + throw new IllegalArgumentException( + String.format( + "'kty' parameter must be one of %s but '%s' " + + "found. Certificate loading for trust domain '%s' failed.", + KTY_PARAMETER_VALUES, kty, trustDomainName)); + } + if (jwkNode.containsKey("kid")) { + throw new IllegalArgumentException(String.format("'kid' parameter must not be set. " + + "Certificate loading for trust domain '%s' failed.", trustDomainName)); + } + String use = JsonUtil.getString(jwkNode, "use"); + if (use == null || !use.equals(USE_PARAMETER_VALUE)) { + throw new IllegalArgumentException(String.format("'use' parameter must be '%s' but '%s' " + + "found. Certificate loading for trust domain '%s' failed.", USE_PARAMETER_VALUE, + use, trustDomainName)); + } + } + + private static List extractCert(List> keysNode, + String trustDomainName) { + List result = new ArrayList<>(); + for (Map keyNode : keysNode) { + checkJwkEntry(keyNode, trustDomainName); + List rawCerts = JsonUtil.getListOfStrings(keyNode, "x5c"); + if (rawCerts == null) { + break; + } + if (rawCerts.size() != 1) { + throw new IllegalArgumentException(String.format("Exactly 1 certificate is expected, but " + + "%s found. 
Certificate loading for trust domain '%s' failed.", rawCerts.size(), + trustDomainName)); + } + InputStream stream = new ByteArrayInputStream((CERTIFICATE_PREFIX + rawCerts.get(0) + "\n" + + CERTIFICATE_SUFFIX) + .getBytes(StandardCharsets.UTF_8)); + try { + Collection certs = CertificateFactory.getInstance("X509") + .generateCertificates(stream); + X509Certificate[] certsArray = certs.toArray(new X509Certificate[0]); + assert certsArray.length == 1; + result.add(certsArray[0]); + } catch (CertificateException e) { + throw new IllegalArgumentException(String.format("Certificate can't be parsed. Certificate " + + "loading for trust domain '%s' failed.", trustDomainName), e); + } + } + return result; + } + + /** + * Represents a SPIFFE ID as defined in the SPIFFE standard. + * @see Standard + */ + public static class SpiffeId { + + private final String trustDomain; + private final String path; + + private SpiffeId(String trustDomain, String path) { + this.trustDomain = trustDomain; + this.path = path; + } + + public String getTrustDomain() { + return trustDomain; + } + + public String getPath() { + return path; + } + } + + /** + * Represents a SPIFFE trust bundle; that is, a map from trust domain to set of trusted + * certificates. Only trust domain's sequence numbers and x509 certificates are supported. 
+ * @see Standard + */ + public static final class SpiffeBundle { + + private final ImmutableMap sequenceNumbers; + + private final ImmutableMap> bundleMap; + + private SpiffeBundle(Map sequenceNumbers, + Map> trustDomainMap) { + this.sequenceNumbers = ImmutableMap.copyOf(sequenceNumbers); + ImmutableMap.Builder> builder = ImmutableMap.builder(); + for (Map.Entry> entry : trustDomainMap.entrySet()) { + builder.put(entry.getKey(), ImmutableList.copyOf(entry.getValue())); + } + this.bundleMap = builder.build(); + } + + public ImmutableMap getSequenceNumbers() { + return sequenceNumbers; + } + + public ImmutableMap> getBundleMap() { + return bundleMap; + } + } + +} diff --git a/core/src/main/java/io/grpc/internal/SubchannelChannel.java b/core/src/main/java/io/grpc/internal/SubchannelChannel.java index 773dcb99dd7..ced4272afe3 100644 --- a/core/src/main/java/io/grpc/internal/SubchannelChannel.java +++ b/core/src/main/java/io/grpc/internal/SubchannelChannel.java @@ -59,7 +59,8 @@ public ClientStream newStream(MethodDescriptor method, transport = notReadyTransport; } ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( - callOptions, headers, 0, /* isTransparentRetry= */ false); + callOptions, headers, 0, /* isTransparentRetry= */ false, + /* isHedging= */ false); Context origContext = context.attach(); try { return transport.newStream(method, headers, callOptions, tracers); diff --git a/core/src/main/java/io/grpc/internal/SubchannelMetrics.java b/core/src/main/java/io/grpc/internal/SubchannelMetrics.java new file mode 100644 index 00000000000..4bc2cf47046 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/SubchannelMetrics.java @@ -0,0 +1,108 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.LongUpDownCounterMetricInstrument; +import io.grpc.MetricInstrumentRegistry; +import io.grpc.MetricRecorder; + +final class SubchannelMetrics { + + private static final LongCounterMetricInstrument disconnections; + private static final LongCounterMetricInstrument connectionAttemptsSucceeded; + private static final LongCounterMetricInstrument connectionAttemptsFailed; + private static final LongUpDownCounterMetricInstrument openConnections; + private final MetricRecorder metricRecorder; + + public SubchannelMetrics(MetricRecorder metricRecorder) { + this.metricRecorder = metricRecorder; + } + + static { + MetricInstrumentRegistry metricInstrumentRegistry + = MetricInstrumentRegistry.getDefaultRegistry(); + disconnections = metricInstrumentRegistry.registerLongCounter( + "grpc.subchannel.disconnections", + "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected", + "{disconnection}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.backend_service", "grpc.lb.locality", "grpc.disconnect_error"), + false + ); + + connectionAttemptsSucceeded = metricInstrumentRegistry.registerLongCounter( + "grpc.subchannel.connection_attempts_succeeded", + "EXPERIMENTAL. 
Number of successful connection attempts", + "{attempt}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.backend_service", "grpc.lb.locality"), + false + ); + + connectionAttemptsFailed = metricInstrumentRegistry.registerLongCounter( + "grpc.subchannel.connection_attempts_failed", + "EXPERIMENTAL. Number of failed connection attempts", + "{attempt}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.backend_service", "grpc.lb.locality"), + false + ); + + openConnections = metricInstrumentRegistry.registerLongUpDownCounter( + "grpc.subchannel.open_connections", + "EXPERIMENTAL. Number of open connections.", + "{connection}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.security_level", "grpc.lb.backend_service", "grpc.lb.locality"), + false + ); + } + + public void recordConnectionAttemptSucceeded(String target, String backendService, + String locality, String securityLevel) { + metricRecorder + .addLongCounter(connectionAttemptsSucceeded, 1, + ImmutableList.of(target), + ImmutableList.of(backendService, locality)); + metricRecorder + .addLongUpDownCounter(openConnections, 1, + ImmutableList.of(target), + ImmutableList.of(securityLevel, backendService, locality)); + } + + public void recordConnectionAttemptFailed(String target, String backendService, String locality) { + metricRecorder + .addLongCounter(connectionAttemptsFailed, 1, + ImmutableList.of(target), + ImmutableList.of(backendService, locality)); + } + + public void recordDisconnection(String target, String backendService, String locality, + String disconnectError, String securityLevel) { + metricRecorder + .addLongCounter(disconnections, 1, + ImmutableList.of(target), + ImmutableList.of(backendService, locality, disconnectError)); + metricRecorder + .addLongUpDownCounter(openConnections, -1, + ImmutableList.of(target), + ImmutableList.of(securityLevel, backendService, locality)); + } +} diff --git 
a/core/src/main/java/io/grpc/internal/TimeProvider.java b/core/src/main/java/io/grpc/internal/TimeProvider.java index b0ea147ada1..3bd052ab3e0 100644 --- a/core/src/main/java/io/grpc/internal/TimeProvider.java +++ b/core/src/main/java/io/grpc/internal/TimeProvider.java @@ -16,8 +16,6 @@ package io.grpc.internal; -import java.util.concurrent.TimeUnit; - /** * Time source representing the current system time in nanos. Used to inject a fake clock * into unit tests. @@ -26,10 +24,5 @@ public interface TimeProvider { /** Returns the current nano time. */ long currentTimeNanos(); - TimeProvider SYSTEM_TIME_PROVIDER = new TimeProvider() { - @Override - public long currentTimeNanos() { - return TimeUnit.MILLISECONDS.toNanos(System.currentTimeMillis()); - } - }; + TimeProvider SYSTEM_TIME_PROVIDER = TimeProviderResolverFactory.resolveTimeProvider(); } diff --git a/core/src/main/java/io/grpc/internal/TimeProviderResolverFactory.java b/core/src/main/java/io/grpc/internal/TimeProviderResolverFactory.java new file mode 100644 index 00000000000..04272034ce9 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/TimeProviderResolverFactory.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +/** + * {@link TimeProviderResolverFactory} resolves Time providers. 
+ */ + +final class TimeProviderResolverFactory { + static TimeProvider resolveTimeProvider() { + try { + Class.forName("java.time.Instant"); + return new InstantTimeProvider(); + } catch (ClassNotFoundException ex) { + return new ConcurrentTimeProvider(); + } + } +} diff --git a/core/src/main/java/io/grpc/internal/TransportFrameUtil.java b/core/src/main/java/io/grpc/internal/TransportFrameUtil.java index f3c32416426..3bd7ee72239 100644 --- a/core/src/main/java/io/grpc/internal/TransportFrameUtil.java +++ b/core/src/main/java/io/grpc/internal/TransportFrameUtil.java @@ -19,13 +19,13 @@ import static java.nio.charset.StandardCharsets.US_ASCII; import com.google.common.io.BaseEncoding; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.InternalMetadata; import io.grpc.Metadata; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; /** * Utility functions for transport layer framing. diff --git a/core/src/main/java/io/grpc/internal/UriWrapper.java b/core/src/main/java/io/grpc/internal/UriWrapper.java new file mode 100644 index 00000000000..ca5835cabd8 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/UriWrapper.java @@ -0,0 +1,139 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.internal; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.grpc.NameResolver; +import io.grpc.Uri; +import java.net.URI; +import javax.annotation.Nullable; + +/** Temporary wrapper for a URI-like object to ease the migration to io.grpc.Uri. */ +interface UriWrapper { + + static UriWrapper wrap(URI uri) { + return new JavaNetUriWrapper(uri); + } + + static UriWrapper wrap(Uri uri) { + return new IoGrpcUriWrapper(uri); + } + + /** Uses the given factory and args to create a {@link NameResolver} for this URI. */ + NameResolver newNameResolver(NameResolver.Factory factory, NameResolver.Args args); + + /** Returns the scheme component of this URI, e.g. "http", "mailto" or "dns". */ + String getScheme(); + + /** + * Returns the authority component of this URI, e.g. "google.com", "127.0.0.1:8080", or null if + * not present. + */ + @Nullable + String getAuthority(); + + /** Wraps an instance of java.net.URI. */ + final class JavaNetUriWrapper implements UriWrapper { + private final URI uri; + + private JavaNetUriWrapper(URI uri) { + this.uri = checkNotNull(uri); + } + + @Override + public NameResolver newNameResolver(NameResolver.Factory factory, NameResolver.Args args) { + return factory.newNameResolver(uri, args); + } + + @Override + public String getScheme() { + return uri.getScheme(); + } + + @Override + public String getAuthority() { + return uri.getAuthority(); + } + + @Override + public String toString() { + return uri.toString(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + if (!(other instanceof JavaNetUriWrapper)) { + return false; + } + return uri.equals(((JavaNetUriWrapper) other).uri); + } + + @Override + public int hashCode() { + return uri.hashCode(); + } + } + + /** Wraps an instance of io.grpc.Uri. 
*/ + final class IoGrpcUriWrapper implements UriWrapper { + private final Uri uri; + + private IoGrpcUriWrapper(Uri uri) { + this.uri = checkNotNull(uri); + } + + @Override + public NameResolver newNameResolver(NameResolver.Factory factory, NameResolver.Args args) { + return factory.newNameResolver(uri, args); + } + + @Override + public String getScheme() { + return uri.getScheme(); + } + + @Override + public String getAuthority() { + return uri.getAuthority(); + } + + @Override + public String toString() { + return uri.toString(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + if (!(other instanceof IoGrpcUriWrapper)) { + return false; + } + return uri.equals(((IoGrpcUriWrapper) other).uri); + } + + @Override + public int hashCode() { + return uri.hashCode(); + } + } +} diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java index ad3b59030d7..8f14b74035c 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; @@ -57,7 +58,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -76,8 +76,6 @@ public class AbstractClientStreamTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = 
ExpectedException.none(); private final StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP; private final TransportTracer transportTracer = new TransportTracer(); @@ -136,9 +134,7 @@ public void cancel_failsOnNull() { AbstractClientStream stream = new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer); stream.start(listener); - thrown.expect(NullPointerException.class); - - stream.cancel(null); + assertThrows(NullPointerException.class, () -> stream.cancel(null)); } @Test @@ -164,9 +160,7 @@ public void startFailsOnNullListener() { AbstractClientStream stream = new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer); - thrown.expect(NullPointerException.class); - - stream.start(null); + assertThrows(NullPointerException.class, () -> stream.start(null)); } @Test @@ -174,9 +168,7 @@ public void cantCallStartTwice() { AbstractClientStream stream = new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer); stream.start(mockListener); - thrown.expect(IllegalStateException.class); - - stream.start(mockListener); + assertThrows(IllegalStateException.class, () -> stream.start(mockListener)); } @Test @@ -188,8 +180,7 @@ public void inboundDataReceived_failsOnNullFrame() { TransportState state = stream.transportState(); - thrown.expect(NullPointerException.class); - state.inboundDataReceived(null); + assertThrows(NullPointerException.class, () -> state.inboundDataReceived(null)); } @Test @@ -212,8 +203,8 @@ public void inboundHeadersReceived_failsIfStatusReported() { TransportState state = stream.transportState(); - thrown.expect(IllegalStateException.class); - state.inboundHeadersReceived(new Metadata()); + Metadata headers = new Metadata(); + assertThrows(IllegalStateException.class, () -> state.inboundHeadersReceived(headers)); } @Test @@ -474,6 +465,24 @@ allocator, new BaseTransportState(statsTraceCtx, transportTracer), sink, statsTr .isGreaterThan(TimeUnit.MILLISECONDS.toNanos(600)); } + @Test + public void 
setDeadline_thePastBecomesPositive() { + AbstractClientStream.Sink sink = mock(AbstractClientStream.Sink.class); + ClientStream stream = new BaseAbstractClientStream( + allocator, new BaseTransportState(statsTraceCtx, transportTracer), sink, statsTraceCtx, + transportTracer); + + stream.setDeadline(Deadline.after(-1, TimeUnit.NANOSECONDS)); + stream.start(mockListener); + + ArgumentCaptor headersCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(sink).writeHeaders(headersCaptor.capture(), ArgumentMatchers.any()); + + Metadata headers = headersCaptor.getValue(); + assertThat(headers.get(Metadata.Key.of("grpc-timeout", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("1n"); + } + @Test public void appendTimeoutInsight() { InsightBuilder insight = new InsightBuilder(); diff --git a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java index 66fa92b1cf8..137ba19bfea 100644 --- a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java @@ -18,6 +18,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; @@ -45,9 +46,7 @@ import java.util.Queue; import java.util.concurrent.TimeUnit; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -60,9 +59,6 @@ public class AbstractServerStreamTest { private static final int TIMEOUT_MS = 1000; private static final int MAX_MESSAGE_SIZE = 100; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = 
ExpectedException.none(); - private final WritableBufferAllocator allocator = new WritableBufferAllocator() { @Override public WritableBuffer allocate(int capacityHint) { @@ -85,7 +81,7 @@ public void setUp() { } /** - * Test for issue https://github.com/grpc/grpc-java/issues/1795 + * Test for issue https://github.com/grpc/grpc-java/issues/1795 . */ @Test public void frameShouldBeIgnoredAfterDeframerClosed() { @@ -212,7 +208,7 @@ public void closed(Status status) { } /** - * Test for issue https://github.com/grpc/grpc-java/issues/615 + * Test for issue https://github.com/grpc/grpc-java/issues/615 . */ @Test public void completeWithoutClose() { @@ -226,9 +222,9 @@ public void completeWithoutClose() { public void setListener_setOnlyOnce() { TransportState state = stream.transportState(); state.setListener(new ServerStreamListenerBase()); - thrown.expect(IllegalStateException.class); - state.setListener(new ServerStreamListenerBase()); + ServerStreamListenerBase listener2 = new ServerStreamListenerBase(); + assertThrows(IllegalStateException.class, () -> state.setListener(listener2)); } @Test @@ -238,8 +234,7 @@ public void listenerReady_onlyOnce() { TransportState state = stream.transportState(); - thrown.expect(IllegalStateException.class); - state.onStreamAllocated(); + assertThrows(IllegalStateException.class, state::onStreamAllocated); } @Test @@ -255,8 +250,7 @@ public void listenerReady_readyCalled() { public void setListener_failsOnNull() { TransportState state = stream.transportState(); - thrown.expect(NullPointerException.class); - state.setListener(null); + assertThrows(NullPointerException.class, () -> state.setListener(null)); } // TODO(ericgribkoff) This test is only valid if deframeInTransportThread=true, as otherwise the @@ -284,9 +278,7 @@ public void messagesAvailable(MessageProducer producer) { @Test public void writeHeaders_failsOnNullHeaders() { - thrown.expect(NullPointerException.class); - - stream.writeHeaders(null, true); + 
assertThrows(NullPointerException.class, () -> stream.writeHeaders(null, true)); } @Test @@ -336,16 +328,13 @@ public void writeMessage_closesStream() throws Exception { @Test public void close_failsOnNullStatus() { - thrown.expect(NullPointerException.class); - - stream.close(null, new Metadata()); + Metadata trailers = new Metadata(); + assertThrows(NullPointerException.class, () -> stream.close(null, trailers)); } @Test public void close_failsOnNullMetadata() { - thrown.expect(NullPointerException.class); - - stream.close(Status.INTERNAL, null); + assertThrows(NullPointerException.class, () -> stream.close(Status.INTERNAL, null)); } @Test @@ -451,4 +440,3 @@ public int streamId() { } } } - diff --git a/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java b/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java index 8d56968737b..07d19d41a86 100644 --- a/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java +++ b/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java @@ -193,7 +193,7 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); LoadBalancer oldDelegate = lb.getDelegate(); - Status addressAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setAttributes(Attributes.EMPTY) @@ -208,7 +208,7 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { public void acceptResolvedAddresses_shutsDownOldBalancer() throws Exception { Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"round_robin\": { } } ] }"); - ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError lbConfigs = lbf.parseLoadBalancingPolicyConfig(serviceConfig); final List servers = Collections.singletonList(new EquivalentAddressGroup(new SocketAddress(){})); @@ 
-235,7 +235,7 @@ public void shutdown() { }; lb.setDelegate(testlb); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -252,7 +252,7 @@ public void shutdown() { public void acceptResolvedAddresses_propagateLbConfigToDelegate() throws Exception { Map rawServiceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { \"setting1\": \"high\" } } ] }"); - ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); + ConfigOrError lbConfigs = lbf.parseLoadBalancingPolicyConfig(rawServiceConfig); assertThat(lbConfigs.getConfig()).isNotNull(); final List servers = @@ -260,7 +260,7 @@ public void acceptResolvedAddresses_propagateLbConfigToDelegate() throws Excepti Helper helper = new TestHelper(); AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -280,9 +280,9 @@ public void acceptResolvedAddresses_propagateLbConfigToDelegate() throws Excepti rawServiceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { \"setting1\": \"low\" } } ] }"); - lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); + lbConfigs = lbf.parseLoadBalancingPolicyConfig(rawServiceConfig); - addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -305,7 +305,7 @@ public void acceptResolvedAddresses_propagateLbConfigToDelegate() throws Excepti public void acceptResolvedAddresses_propagateAddrsToDelegate() throws Exception { Map rawServiceConfig = 
parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { \"setting1\": \"high\" } } ] }"); - ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); + ConfigOrError lbConfigs = lbf.parseLoadBalancingPolicyConfig(rawServiceConfig); assertThat(lbConfigs.getConfig()).isNotNull(); Helper helper = new TestHelper(); @@ -313,7 +313,7 @@ public void acceptResolvedAddresses_propagateAddrsToDelegate() throws Exception List servers = Collections.singletonList(new EquivalentAddressGroup(new InetSocketAddress(8080){})); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -329,7 +329,7 @@ public void acceptResolvedAddresses_propagateAddrsToDelegate() throws Exception servers = Collections.singletonList(new EquivalentAddressGroup(new InetSocketAddress(9090){})); - addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -353,8 +353,8 @@ public void acceptResolvedAddresses_delegateDoNotAcceptEmptyAddressList_nothing( Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { \"setting1\": \"high\" } } ] }"); - ConfigOrError lbConfig = lbf.parseLoadBalancerPolicy(serviceConfig); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + ConfigOrError lbConfig = lbf.parseLoadBalancingPolicyConfig(serviceConfig); + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setLoadBalancingPolicyConfig(lbConfig.getConfig()) @@ -373,8 +373,8 @@ public void acceptResolvedAddresses_delegateAcceptsEmptyAddressList() Map rawServiceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb2\": { \"setting1\": 
\"high\" } } ] }"); ConfigOrError lbConfigs = - lbf.parseLoadBalancerPolicy(rawServiceConfig); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + lbf.parseLoadBalancingPolicyConfig(rawServiceConfig); + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -394,7 +394,7 @@ public void acceptResolvedAddresses_delegateAcceptsEmptyAddressList() public void acceptResolvedAddresses_useSelectedLbPolicy() throws Exception { Map rawServiceConfig = parseConfig("{\"loadBalancingConfig\": [{\"round_robin\": {}}]}"); - ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); + ConfigOrError lbConfigs = lbf.parseLoadBalancingPolicyConfig(rawServiceConfig); assertThat(lbConfigs.getConfig()).isNotNull(); assertThat(((PolicySelection) lbConfigs.getConfig()).provider.getClass().getName()) .isEqualTo("io.grpc.util.SecretRoundRobinLoadBalancerProvider$Provider"); @@ -409,7 +409,7 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { } }; AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -431,7 +431,7 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { } }; AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(null) @@ -446,7 +446,7 @@ public void acceptResolvedAddresses_noLbPolicySelected_defaultToCustomDefault() .newLoadBalancer(new TestHelper()); List servers = Collections.singletonList(new EquivalentAddressGroup(new 
SocketAddress(){})); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(null) @@ -468,7 +468,7 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { AutoConfiguredLoadBalancer lb = new AutoConfiguredLoadBalancerFactory(GrpcUtil.DEFAULT_LB_POLICY).newLoadBalancer(helper); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setAttributes(Attributes.EMPTY) @@ -481,8 +481,8 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { nextParsedConfigOrError.set(testLbParsedConfig); Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { } } ] }"); - ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(serviceConfig); - addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + ConfigOrError lbConfigs = lbf.parseLoadBalancingPolicyConfig(serviceConfig); + addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -504,8 +504,8 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { testLbParsedConfig = ConfigOrError.fromConfig("bar"); nextParsedConfigOrError.set(testLbParsedConfig); serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { } } ] }"); - lbConfigs = lbf.parseLoadBalancerPolicy(serviceConfig); - addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + lbConfigs = lbf.parseLoadBalancingPolicyConfig(serviceConfig); + addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -519,33 +519,33 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { } @Test - public void 
parseLoadBalancerConfig_failedOnUnknown() throws Exception { + public void parseLoadBalancingConfig_failedOnUnknown() throws Exception { Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"magic_balancer\": {} } ] }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed.getError()).isNotNull(); assertThat(parsed.getError().getDescription()) .isEqualTo("None of [magic_balancer] specified by Service Config are available."); } @Test - public void parseLoadBalancerPolicy_failedOnUnknown() throws Exception { + public void parseLoadBalancingPolicy_failedOnUnknown() throws Exception { Map serviceConfig = parseConfig("{\"loadBalancingPolicy\": \"magic_balancer\"}"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed.getError()).isNotNull(); assertThat(parsed.getError().getDescription()) .isEqualTo("None of [magic_balancer] specified by Service Config are available."); } @Test - public void parseLoadBalancerConfig_multipleValidPolicies() throws Exception { + public void parseLoadBalancingConfig_multipleValidPolicies() throws Exception { Map serviceConfig = parseConfig( "{\"loadBalancingConfig\": [" + "{\"round_robin\": {}}," + "{\"test_lb\": {} } ] }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed).isNotNull(); assertThat(parsed.getError()).isNull(); assertThat(parsed.getConfig()).isInstanceOf(PolicySelection.class); @@ -554,12 +554,12 @@ public void parseLoadBalancerConfig_multipleValidPolicies() throws Exception { } @Test - public void parseLoadBalancerConfig_policyShouldBeIgnoredIfConfigExists() throws Exception { + public void parseLoadBalancingConfig_policyShouldBeIgnoredIfConfigExists() throws Exception { Map 
serviceConfig = parseConfig( "{\"loadBalancingConfig\": [{\"round_robin\": {} } ]," + "\"loadBalancingPolicy\": \"pick_first\" }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed).isNotNull(); assertThat(parsed.getError()).isNull(); assertThat(parsed.getConfig()).isInstanceOf(PolicySelection.class); @@ -568,13 +568,13 @@ public void parseLoadBalancerConfig_policyShouldBeIgnoredIfConfigExists() throws } @Test - public void parseLoadBalancerConfig_policyShouldBeIgnoredEvenIfUnknownPolicyExists() + public void parseLoadBalancingConfig_policyShouldBeIgnoredEvenIfUnknownPolicyExists() throws Exception { Map serviceConfig = parseConfig( "{\"loadBalancingConfig\": [{\"magic_balancer\": {} } ]," + "\"loadBalancingPolicy\": \"round_robin\" }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed.getError()).isNotNull(); assertThat(parsed.getError().getDescription()) .isEqualTo("None of [magic_balancer] specified by Service Config are available."); @@ -582,7 +582,7 @@ public void parseLoadBalancerConfig_policyShouldBeIgnoredEvenIfUnknownPolicyExis @Test @SuppressWarnings("unchecked") - public void parseLoadBalancerConfig_firstInvalidPolicy() throws Exception { + public void parseLoadBalancingConfig_firstInvalidPolicy() throws Exception { when(testLbBalancerProvider.parseLoadBalancingPolicyConfig(any(Map.class))) .thenReturn(ConfigOrError.fromError(Status.UNKNOWN)); Map serviceConfig = @@ -590,7 +590,7 @@ public void parseLoadBalancerConfig_firstInvalidPolicy() throws Exception { "{\"loadBalancingConfig\": [" + "{\"test_lb\": {}}," + "{\"round_robin\": {} } ] }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed).isNotNull(); 
assertThat(parsed.getConfig()).isNull(); assertThat(parsed.getError()).isEqualTo(Status.UNKNOWN); @@ -598,7 +598,7 @@ public void parseLoadBalancerConfig_firstInvalidPolicy() throws Exception { @Test @SuppressWarnings("unchecked") - public void parseLoadBalancerConfig_firstValidSecondInvalidPolicy() throws Exception { + public void parseLoadBalancingConfig_firstValidSecondInvalidPolicy() throws Exception { when(testLbBalancerProvider.parseLoadBalancingPolicyConfig(any(Map.class))) .thenReturn(ConfigOrError.fromError(Status.UNKNOWN)); Map serviceConfig = @@ -606,32 +606,32 @@ public void parseLoadBalancerConfig_firstValidSecondInvalidPolicy() throws Excep "{\"loadBalancingConfig\": [" + "{\"round_robin\": {}}," + "{\"test_lb\": {} } ] }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed).isNotNull(); assertThat(parsed.getConfig()).isNotNull(); assertThat(((PolicySelection) parsed.getConfig()).config).isNotNull(); } @Test - public void parseLoadBalancerConfig_someProvidesAreNotAvailable() throws Exception { + public void parseLoadBalancingConfig_someProvidesAreNotAvailable() throws Exception { Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ " + "{\"magic_balancer\": {} }," + "{\"round_robin\": {}} ] }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed).isNotNull(); assertThat(parsed.getConfig()).isNotNull(); assertThat(((PolicySelection) parsed.getConfig()).config).isNotNull(); } @Test - public void parseLoadBalancerConfig_lbConfigPropagated() throws Exception { + public void parseLoadBalancingConfig_lbConfigPropagated() throws Exception { Map rawServiceConfig = parseConfig( "{\"loadBalancingConfig\": [" + "{\"pick_first\": {\"shuffleAddressList\": true } }" + "] }"); - ConfigOrError parsed = 
lbf.parseLoadBalancerPolicy(rawServiceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(rawServiceConfig); assertThat(parsed).isNotNull(); assertThat(parsed.getConfig()).isNotNull(); PolicySelection policySelection = (PolicySelection) parsed.getConfig(); diff --git a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java index 66d626ec2b6..03e613e13d9 100644 --- a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java @@ -1105,6 +1105,32 @@ public void getAttributes() { assertEquals(attrs, call.getAttributes()); } + @Test + public void onCloseExceptionCaughtAndLogged() { + DelayedExecutor executor = new DelayedExecutor(); + ClientCallImpl call = new ClientCallImpl<>( + method, + executor, + baseCallOptions, + clientStreamProvider, + deadlineCancellationExecutor, + channelCallTracer, configSelector); + + call.start(callListener, new Metadata()); + verify(stream).start(listenerArgumentCaptor.capture()); + final ClientStreamListener streamListener = listenerArgumentCaptor.getValue(); + streamListener.headersRead(new Metadata()); + + doThrow(new RuntimeException("Exception thrown by onClose() in ClientCall")).when(callListener) + .onClose(any(Status.class), any(Metadata.class)); + + Status status = Status.RESOURCE_EXHAUSTED.withDescription("simulated"); + streamListener.closed(status, PROCESSED, new Metadata()); + executor.release(); + + verify(callListener).onClose(same(status), any(Metadata.class)); + } + private static final class DelayedExecutor implements Executor { private final BlockingQueue commands = new LinkedBlockingQueue<>(); diff --git a/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java b/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java index 8d9248a8910..749b71d681e 100644 --- a/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java +++ 
b/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java @@ -28,8 +28,6 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.nio.Buffer; -import java.nio.ByteBuffer; import java.nio.InvalidMarkException; import org.junit.After; import org.junit.Before; @@ -121,27 +119,6 @@ public void readByteArrayShouldSucceed() { assertEquals(EXPECTED_VALUE, new String(bytes, UTF_8)); } - @Test - public void readByteBufferShouldSucceed() { - ByteBuffer byteBuffer = ByteBuffer.allocate(EXPECTED_VALUE.length()); - int remaining = EXPECTED_VALUE.length(); - - ((Buffer) byteBuffer).limit(1); - composite.readBytes(byteBuffer); - remaining--; - assertEquals(remaining, composite.readableBytes()); - - ((Buffer) byteBuffer).limit(byteBuffer.limit() + 5); - composite.readBytes(byteBuffer); - remaining -= 5; - assertEquals(remaining, composite.readableBytes()); - - ((Buffer) byteBuffer).limit(byteBuffer.limit() + remaining); - composite.readBytes(byteBuffer); - assertEquals(0, composite.readableBytes()); - assertEquals(EXPECTED_VALUE, new String(byteBuffer.array(), UTF_8)); - } - @Test public void readStreamShouldSucceed() throws IOException { ByteArrayOutputStream bos = new ByteArrayOutputStream(); @@ -216,18 +193,6 @@ public void markAndResetWithReadByteArrayShouldSucceed() { assertArrayEquals(first, second); } - @Test - public void markAndResetWithReadByteBufferShouldSucceed() { - byte[] first = new byte[EXPECTED_VALUE.length()]; - composite.mark(); - composite.readBytes(ByteBuffer.wrap(first)); - composite.reset(); - byte[] second = new byte[EXPECTED_VALUE.length()]; - assertEquals(EXPECTED_VALUE.length(), composite.readableBytes()); - composite.readBytes(ByteBuffer.wrap(second)); - assertArrayEquals(first, second); - } - @Test public void markAndResetWithReadStreamShouldSucceed() throws IOException { ByteArrayOutputStream first = new ByteArrayOutputStream(); diff --git a/core/src/test/java/io/grpc/internal/ConcurrentTimeProviderTest.java 
b/core/src/test/java/io/grpc/internal/ConcurrentTimeProviderTest.java new file mode 100644 index 00000000000..7983530456c --- /dev/null +++ b/core/src/test/java/io/grpc/internal/ConcurrentTimeProviderTest.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.truth.Truth.assertThat; + +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link ConcurrentTimeProvider}. + */ +@RunWith(JUnit4.class) +public class ConcurrentTimeProviderTest { + @Test + public void testConcurrentCurrentTimeNanos() { + + ConcurrentTimeProvider concurrentTimeProvider = new ConcurrentTimeProvider(); + // Get the current time from the ConcurrentTimeProvider + long actualTimeNanos = concurrentTimeProvider.currentTimeNanos(); + + // Get the current time from System for comparison + long expectedTimeNanos = TimeUnit.MILLISECONDS.toNanos(System.currentTimeMillis()); + + // Validate the time returned is close to the expected value within a tolerance + // (i,e 10 millisecond tolerance in nanoseconds). 
+ assertThat(actualTimeNanos).isWithin(10_000_000L).of(expectedTimeNanos); + } +} diff --git a/core/src/test/java/io/grpc/internal/ConnectivityStateManagerTest.java b/core/src/test/java/io/grpc/internal/ConnectivityStateManagerTest.java index 2a759a4f386..dfd6ed56a1e 100644 --- a/core/src/test/java/io/grpc/internal/ConnectivityStateManagerTest.java +++ b/core/src/test/java/io/grpc/internal/ConnectivityStateManagerTest.java @@ -27,9 +27,7 @@ import io.grpc.ConnectivityState; import java.util.LinkedList; import java.util.concurrent.Executor; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -38,10 +36,6 @@ */ @RunWith(JUnit4.class) public class ConnectivityStateManagerTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - private final FakeClock executor = new FakeClock(); private final ConnectivityStateManager state = new ConnectivityStateManager(); private final LinkedList sink = new LinkedList<>(); @@ -75,7 +69,7 @@ public void run() { assertEquals(1, sink.size()); assertEquals(TRANSIENT_FAILURE, sink.poll()); } - + @Test public void registerCallbackAfterStateChanged() { state.gotoState(CONNECTING); diff --git a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java index 45682b3a385..ff131d29975 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java @@ -151,10 +151,12 @@ public void startThenSetCall() { delayedClientCall.request(1); Runnable r = delayedClientCall.setCall(mockRealCall); assertThat(r).isNotNull(); - r.run(); @SuppressWarnings("unchecked") ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + // start() must be called before setCall() returns (not 
in runnable), to ensure the in-use + // counts keeping the channel alive after shutdown() don't momentarily decrease to zero. verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + r.run(); Listener realCallListener = listenerCaptor.getValue(); verify(mockRealCall).request(1); realCallListener.onMessage(1); @@ -204,7 +206,7 @@ public void delayedCallsRunUnderContext() throws Exception { Object goldenValue = new Object(); DelayedClientCall delayedClientCall = Context.current().withValue(contextKey, goldenValue).call(() -> - new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null)); + new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null)); AtomicReference readyContext = new AtomicReference<>(); delayedClientCall.start(new ClientCall.Listener() { @Override public void onReady() { diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index c7ae8c8b4be..d7e1d4ca4f6 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -175,7 +175,8 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.reprocess(mockPicker); assertEquals(0, delayedTransport.getPendingStreamsCount()); delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); assertEquals(0, fakeExecutor.runDueTasks()); verify(mockRealTransport).newStream( @@ -187,7 +188,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void transportTerminatedThenAssignTransport() { delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + 
verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); delayedTransport.reprocess(mockPicker); verifyNoMoreInteractions(transportListener); @@ -196,7 +198,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void assignTransportThenShutdownThenNewStream() { delayedTransport.reprocess(mockPicker); delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); ClientStream stream = delayedTransport.newStream( method, headers, callOptions, tracers); @@ -210,7 +213,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void assignTransportThenShutdownNowThenNewStream() { delayedTransport.reprocess(mockPicker); delayedTransport.shutdownNow(Status.UNAVAILABLE); - verify(transportListener).transportShutdown(any(Status.class)); + verify(transportListener).transportShutdown(any(Status.class), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); ClientStream stream = delayedTransport.newStream( method, headers, callOptions, tracers); @@ -241,7 +245,8 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); // Stream is still buffered - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener, times(0)).transportTerminated(); assertEquals(1, delayedTransport.getPendingStreamsCount()); @@ -275,7 +280,8 @@ public void uncaughtException(Thread t, Throwable e) { ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); 
delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener, times(0)).transportTerminated(); assertEquals(1, delayedTransport.getPendingStreamsCount()); stream.start(streamListener); @@ -288,7 +294,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void shutdownThenNewStream() { delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); @@ -303,7 +310,8 @@ public void uncaughtException(Thread t, Throwable e) { method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); delayedTransport.shutdownNow(Status.UNAVAILABLE); - verify(transportListener).transportShutdown(any(Status.class)); + verify(transportListener).transportShutdown(any(Status.class), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); verify(streamListener) .closed(statusCaptor.capture(), eq(RpcProgress.REFUSED), any(Metadata.class)); @@ -312,7 +320,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void shutdownNowThenNewStream() { delayedTransport.shutdownNow(Status.UNAVAILABLE); - verify(transportListener).transportShutdown(any(Status.class)); + verify(transportListener).transportShutdown(any(Status.class), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); @@ -487,7 +496,8 @@ public void uncaughtException(Thread t, 
Throwable e) { // wfr5 will stop delayed transport from terminating delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener, never()).transportTerminated(); // ... until it's gone picker = mock(SubchannelPicker.class); @@ -502,6 +512,31 @@ public void uncaughtException(Thread t, Throwable e) { verify(transportListener).transportTerminated(); } + @Test + public void reprocess_authorityOverrideFromLb() { + InOrder inOrder = inOrder(mockRealStream); + DelayedStream delayedStream = (DelayedStream) delayedTransport.newStream( + method, headers, callOptions.withAuthority(null), tracers); + delayedStream.setAuthority("authority-override-from-calloptions"); + delayedStream.start(mock(ClientStreamListener.class)); + SubchannelPicker picker = mock(SubchannelPicker.class); + PickResult pickResult = PickResult.withSubchannel( + mockSubchannel, null, "authority-override-hostname-from-lb"); + when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult); + when(mockRealTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) + .thenReturn(mockRealStream); + + delayedTransport.reprocess(picker); + fakeExecutor.runDueTasks(); + + // Must be set before start(), and may be overwritten + inOrder.verify(mockRealStream).setAuthority("authority-override-hostname-from-lb"); + inOrder.verify(mockRealStream).setAuthority("authority-override-from-calloptions"); + inOrder.verify(mockRealStream).start(any(ClientStreamListener.class)); + } + @Test public void reprocess_NoPendingStream() { SubchannelPicker picker = mock(SubchannelPicker.class); @@ -525,6 +560,53 @@ public void reprocess_NoPendingStream() { assertSame(mockRealStream, stream); } + @Test + public void newStream_authorityOverrideFromLb() { + InOrder inOrder = 
inOrder(mockRealStream); + SubchannelPicker picker = mock(SubchannelPicker.class); + PickResult pickResult = PickResult.withSubchannel( + mockSubchannel, null, "authority-override-hostname-from-lb"); + when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult); + when(mockRealTransport.newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), any())) + .thenReturn(mockRealStream); + delayedTransport.reprocess(picker); + + ClientStream stream = delayedTransport.newStream(method, headers, callOptions, tracers); + assertThat(stream).isSameInstanceAs(mockRealStream); + stream.setAuthority("authority-override-from-calloptions"); + stream.start(mock(ClientStreamListener.class)); + + // Must be set before start(), and may be overwritten + inOrder.verify(mockRealStream).setAuthority("authority-override-hostname-from-lb"); + inOrder.verify(mockRealStream).setAuthority("authority-override-from-calloptions"); + inOrder.verify(mockRealStream).start(any(ClientStreamListener.class)); + } + + @Test + public void newStream_assignsTransport_authorityFromLB() { + SubchannelPicker picker = mock(SubchannelPicker.class); + AbstractSubchannel subchannel = mock(AbstractSubchannel.class); + when(subchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel); + PickResult pickResult = PickResult.withSubchannel( + subchannel, null, "authority-override-hostname-from-lb"); + when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult); + ArgumentCaptor callOptionsArgumentCaptor = + ArgumentCaptor.forClass(CallOptions.class); + when(mockRealTransport.newStream( + any(MethodDescriptor.class), any(Metadata.class), callOptionsArgumentCaptor.capture(), + ArgumentMatchers.any())) + .thenReturn(mockRealStream); + delayedTransport.reprocess(picker); + verifyNoMoreInteractions(picker); + verifyNoMoreInteractions(transportListener); + + CallOptions callOptions = CallOptions.DEFAULT; + delayedTransport.newStream(method, 
headers, callOptions, tracers); + assertThat(callOptionsArgumentCaptor.getValue().getAuthority()).isEqualTo( + "authority-override-hostname-from-lb"); + } + @Test public void reprocess_newStreamRacesWithReprocess() throws Exception { final CyclicBarrier barrier = new CyclicBarrier(2); @@ -670,7 +752,24 @@ public void pendingStream_appendTimeoutInsight_waitForReady() { InsightBuilder insight = new InsightBuilder(); stream.appendTimeoutInsight(insight); assertThat(insight.toString()) - .matches("\\[wait_for_ready, buffered_nanos=[0-9]+\\, waiting_for_connection]"); + .matches("\\[wait_for_ready, connecting_and_lb_delay=[0-9]+ns\\, was_still_waiting]"); + } + + @Test + public void pendingStream_appendTimeoutInsight_waitForReady_withLastPickFailure() { + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions.withWaitForReady(), tracers); + stream.start(streamListener); + SubchannelPicker picker = mock(SubchannelPicker.class); + when(picker.pickSubchannel(any(PickSubchannelArgs.class))) + .thenReturn(PickResult.withError(Status.PERMISSION_DENIED)); + delayedTransport.reprocess(picker); + InsightBuilder insight = new InsightBuilder(); + stream.appendTimeoutInsight(insight); + assertThat(insight.toString()) + .matches("\\[wait_for_ready, " + + "Last Pick Failure=Status\\{code=PERMISSION_DENIED, description=null, cause=null\\}," + + " connecting_and_lb_delay=[0-9]+ns, was_still_waiting]"); } private static TransportProvider newTransportProvider(final ClientTransport transport) { diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index e39e8d420a2..12c32fcf126 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -71,7 +71,7 @@ public class DelayedStreamTest { @Mock private ClientStreamListener listener; @Mock private ClientStream realStream; @Captor private ArgumentCaptor 
listenerCaptor; - private DelayedStream stream = new DelayedStream(); + private DelayedStream stream = new DelayedStream("test_op"); @Test public void setStream_setAuthority() { @@ -84,12 +84,6 @@ public void setStream_setAuthority() { inOrder.verify(realStream).start(any(ClientStreamListener.class)); } - @Test(expected = IllegalStateException.class) - public void setAuthority_afterStart() { - stream.start(listener); - stream.setAuthority("notgonnawork"); - } - @Test(expected = IllegalStateException.class) public void start_afterStart() { stream.start(listener); @@ -456,7 +450,7 @@ public void appendTimeoutInsight_realStreamNotSet() { InsightBuilder insight = new InsightBuilder(); stream.start(listener); stream.appendTimeoutInsight(insight); - assertThat(insight.toString()).matches("\\[buffered_nanos=[0-9]+\\, waiting_for_connection]"); + assertThat(insight.toString()).matches("\\[test_op_delay=[0-9]+ns\\, was_still_waiting]"); } @Test @@ -475,7 +469,7 @@ public Void answer(InvocationOnMock in) { InsightBuilder insight = new InsightBuilder(); stream.appendTimeoutInsight(insight); assertThat(insight.toString()) - .matches("\\[buffered_nanos=[0-9]+, remote_addr=127\\.0\\.0\\.1:443\\]"); + .matches("\\[test_op_delay=[0-9]+ns, remote_addr=127\\.0\\.0\\.1:443\\]"); } private void callMeMaybe(Runnable r) { diff --git a/core/src/test/java/io/grpc/internal/DnsNameResolverProviderTest.java b/core/src/test/java/io/grpc/internal/DnsNameResolverProviderTest.java index aff10ce9337..75b82df544f 100644 --- a/core/src/test/java/io/grpc/internal/DnsNameResolverProviderTest.java +++ b/core/src/test/java/io/grpc/internal/DnsNameResolverProviderTest.java @@ -16,8 +16,9 @@ package io.grpc.internal; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.TruthJUnit.assume; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; 
import static org.mockito.Mockito.mock; @@ -25,16 +26,27 @@ import io.grpc.NameResolver; import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.SynchronizationContext; +import io.grpc.Uri; import java.net.URI; +import java.util.Arrays; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; /** Unit tests for {@link DnsNameResolverProvider}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class DnsNameResolverProviderTest { private final FakeClock fakeClock = new FakeClock(); + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @Override @@ -59,10 +71,75 @@ public void isAvailable() { } @Test - public void newNameResolver() { - assertSame(DnsNameResolver.class, - provider.newNameResolver(URI.create("dns:///localhost:443"), args).getClass()); - assertNull( - provider.newNameResolver(URI.create("notdns:///localhost:443"), args)); + public void newNameResolver_acceptsHostAndPort() { + NameResolver nameResolver = newNameResolver("dns:///localhost:443", args); + assertThat(nameResolver).isNotNull(); + assertThat(nameResolver.getClass()).isSameInstanceAs(DnsNameResolver.class); + assertThat(nameResolver.getServiceAuthority()).isEqualTo("localhost:443"); + assertThat(((DnsNameResolver) nameResolver).getPort()).isEqualTo(443); + } + + @Test + public void newNameResolver_acceptsRootless() { + assume().that(enableRfc3986UrisParam).isTrue(); + NameResolver nameResolver = newNameResolver("dns:localhost:443", args); + assertThat(nameResolver).isNotNull(); + 
assertThat(nameResolver.getClass()).isSameInstanceAs(DnsNameResolver.class); + assertThat(nameResolver.getServiceAuthority()).isEqualTo("localhost:443"); + } + + @Test + public void newNameResolver_rejectsNonDnsScheme() { + NameResolver nameResolver = newNameResolver("notdns:///localhost:443", args); + assertThat(nameResolver).isNull(); + } + + @Test + public void newNameResolver_validDnsNameWithoutPort_usesDefaultPort() { + DnsNameResolver nameResolver = + (DnsNameResolver) newNameResolver("dns:/foo.googleapis.com", args); + assertThat(nameResolver).isNotNull(); + assertThat(nameResolver.getServiceAuthority()).isEqualTo("foo.googleapis.com"); + assertThat(nameResolver.getPort()).isEqualTo(args.getDefaultPort()); + } + + // TODO(jdcormie): Trailing path segments *should* be forbidden. This test just demonstrates that + // both newNameResolver() overloads behave the same with respect to this bug. + @Test + public void newNameResolver_toleratesTrailingPathSegments() { + NameResolver nameResolver = newNameResolver("dns:///foo.googleapis.com/ig/nor/ed", args); + assertThat(nameResolver).isNotNull(); + assertThat(nameResolver.getClass()).isSameInstanceAs(DnsNameResolver.class); + assertThat(nameResolver.getServiceAuthority()).isEqualTo("foo.googleapis.com"); + } + + @Test + public void newNameResolver_toleratesAuthority() { + NameResolver nameResolver = newNameResolver("dns://8.8.8.8/foo.googleapis.com", args); + assertThat(nameResolver).isNotNull(); + assertThat(nameResolver.getClass()).isSameInstanceAs(DnsNameResolver.class); + assertThat(nameResolver.getServiceAuthority()).isEqualTo("foo.googleapis.com"); + } + + @Test + public void newNameResolver_validIpv6Host() { + NameResolver nameResolver = newNameResolver("dns:/%5B::1%5D", args); + assertThat(nameResolver).isNotNull(); + assertThat(nameResolver.getClass()).isSameInstanceAs(DnsNameResolver.class); + assertThat(nameResolver.getServiceAuthority()).isEqualTo("[::1]"); + } + + @Test + public void 
newNameResolver_invalidIpv6Host_throws() { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> newNameResolver("dns:/%5Binvalid%5D", args)); + assertThat(e).hasMessageThat().contains("invalid"); + } + + private NameResolver newNameResolver(String uriString, NameResolver.Args args) { + return enableRfc3986UrisParam + ? provider.newNameResolver(Uri.create(uriString), args) + : provider.newNameResolver(URI.create(uriString), args); } } diff --git a/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java b/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java index 0512171f4e7..c53863dcf5d 100644 --- a/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java @@ -17,13 +17,16 @@ package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.internal.DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; @@ -34,12 +37,14 @@ import static org.mockito.Mockito.when; import com.google.common.base.Stopwatch; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.net.InetAddresses; import com.google.common.testing.FakeTicker; import io.grpc.ChannelLogger; import io.grpc.EquivalentAddressGroup; +import io.grpc.FlagResetRule; import io.grpc.HttpConnectProxiedSocketAddress; import 
io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; @@ -60,7 +65,6 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; -import java.net.URI; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Arrays; @@ -75,13 +79,10 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.regex.Pattern; -import javax.annotation.Nullable; -import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.DisableOnDebug; -import org.junit.rules.ExpectedException; import org.junit.rules.TestRule; import org.junit.rules.Timeout; import org.junit.runner.RunWith; @@ -98,8 +99,7 @@ public class DnsNameResolverTest { @Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(10)); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); + @Rule public final FlagResetRule flagResetRule = new FlagResetRule(); private final Map serviceConfig = new LinkedHashMap<>(); @@ -112,7 +112,6 @@ public void uncaughtException(Thread t, Throwable e) { } }); - private final DnsNameResolverProvider provider = new DnsNameResolverProvider(); private final FakeClock fakeClock = new FakeClock(); private final FakeClock fakeExecutor = new FakeClock(); private static final FakeClock.TaskFilter NAME_RESOLVER_REFRESH_TASK_FILTER = @@ -139,37 +138,19 @@ public Executor create() { public void close(Executor instance) {} } - private final NameResolver.Args args = NameResolver.Args.newBuilder() - .setDefaultPort(DEFAULT_PORT) - .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR) - .setSynchronizationContext(syncContext) - .setServiceConfigParser(mock(ServiceConfigParser.class)) - .setChannelLogger(mock(ChannelLogger.class)) - 
.setScheduledExecutorService(fakeExecutor.getScheduledExecutorService()) - .build(); - @Mock private NameResolver.Listener2 mockListener; @Captor private ArgumentCaptor resultCaptor; - @Captor - private ArgumentCaptor errorCaptor; - @Nullable - private String networkaddressCacheTtlPropertyValue; @Mock private RecordFetcher recordFetcher; + @Mock private ProxyDetector mockProxyDetector; private RetryingNameResolver newResolver(String name, int defaultPort) { return newResolver( name, defaultPort, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted()); } - private RetryingNameResolver newResolver(String name, int defaultPort, boolean isAndroid) { - return newResolver( - name, defaultPort, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted(), - isAndroid); - } - private RetryingNameResolver newResolver( String name, int defaultPort, @@ -208,59 +189,17 @@ private RetryingNameResolver newResolver( // In practice the DNS name resolver provider always wraps the resolver in a // RetryingNameResolver which adds retry capabilities to it. We use the same setup here. - return new RetryingNameResolver( - dnsResolver, - new BackoffPolicyRetryScheduler( - new ExponentialBackoffPolicy.Provider(), - fakeExecutor.getScheduledExecutorService(), - syncContext - ), - syncContext); + return (RetryingNameResolver) RetryingNameResolver.wrap(dnsResolver, args); } @Before public void setUp() { DnsNameResolver.enableJndi = true; - networkaddressCacheTtlPropertyValue = - System.getProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY); // By default the mock listener processes the result successfully. 
when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.OK); } - @After - public void restoreSystemProperty() { - if (networkaddressCacheTtlPropertyValue == null) { - System.clearProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY); - } else { - System.setProperty( - DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, - networkaddressCacheTtlPropertyValue); - } - } - - @Test - public void invalidDnsName() throws Exception { - testInvalidUri(new URI("dns", null, "/[invalid]", null)); - } - - @Test - public void validIpv6() throws Exception { - testValidUri(new URI("dns", null, "/[::1]", null), "[::1]", DEFAULT_PORT); - } - - @Test - public void validDnsNameWithoutPort() throws Exception { - testValidUri(new URI("dns", null, "/foo.googleapis.com", null), - "foo.googleapis.com", DEFAULT_PORT); - } - - @Test - public void validDnsNameWithPort() throws Exception { - testValidUri(new URI("dns", null, "/foo.googleapis.com:456", null), - "foo.googleapis.com:456", 456); - } - @Test public void nullDnsName() { try { @@ -281,30 +220,14 @@ public void invalidDnsName_containsUnderscore() { } } - @Test - public void resolve_androidIgnoresPropertyValue() throws Exception { - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, Long.toString(2)); - resolveNeverCache(true); - } - - @Test - public void resolve_androidIgnoresPropertyValueCacheForever() throws Exception { - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, Long.toString(-1)); - resolveNeverCache(true); - } - @Test public void resolve_neverCache() throws Exception { - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, "0"); - resolveNeverCache(false); - } - - private void resolveNeverCache(boolean isAndroid) throws Exception { + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, "0"); final List answer1 = createAddressList(2); final List answer2 = createAddressList(1); String name = "foo.googleapis.com"; - 
RetryingNameResolver resolver = newResolver(name, 81, isAndroid); + RetryingNameResolver resolver = newResolver(name, 81); DnsNameResolver dnsResolver = (DnsNameResolver) resolver.getRetriedNameResolver(); AddressResolver mockResolver = mock(AddressResolver.class); when(mockResolver.resolveAddress(anyString())).thenReturn(answer1).thenReturn(answer2); @@ -395,7 +318,7 @@ public void execute(Runnable command) { @Test public void resolve_cacheForever() throws Exception { - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, "-1"); + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, "-1"); final List answer1 = createAddressList(2); String name = "foo.googleapis.com"; FakeTicker fakeTicker = new FakeTicker(); @@ -429,7 +352,7 @@ public void resolve_cacheForever() throws Exception { @Test public void resolve_usingCache() throws Exception { long ttl = 60; - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, Long.toString(ttl)); + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, Long.toString(ttl)); final List answer = createAddressList(2); String name = "foo.googleapis.com"; FakeTicker fakeTicker = new FakeTicker(); @@ -464,7 +387,7 @@ public void resolve_usingCache() throws Exception { @Test public void resolve_cacheExpired() throws Exception { long ttl = 60; - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, Long.toString(ttl)); + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, Long.toString(ttl)); final List answer1 = createAddressList(2); final List answer2 = createAddressList(1); String name = "foo.googleapis.com"; @@ -497,26 +420,38 @@ public void resolve_cacheExpired() throws Exception { verify(mockResolver, times(2)).resolveAddress(anyString()); } + @Test + public void resolve_androidIgnoresPropertyValue() throws Exception { + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, "2"); + 
resolveDefaultValue(true); + } + + @Test + public void resolve_androidIgnoresPropertyValueCacheForever() throws Exception { + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, "-1"); + resolveDefaultValue(true); + } + @Test public void resolve_invalidTtlPropertyValue() throws Exception { - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, "not_a_number"); - resolveDefaultValue(); + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, "not_a_number"); + resolveDefaultValue(false); } @Test public void resolve_noPropertyValue() throws Exception { - System.clearProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY); - resolveDefaultValue(); + flagResetRule.clearSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY); + resolveDefaultValue(false); } - private void resolveDefaultValue() throws Exception { + private void resolveDefaultValue(boolean isAndroid) throws Exception { final List answer1 = createAddressList(2); final List answer2 = createAddressList(1); String name = "foo.googleapis.com"; FakeTicker fakeTicker = new FakeTicker(); RetryingNameResolver resolver = newResolver( - name, 81, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted(fakeTicker)); + name, 81, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted(fakeTicker), isAndroid); DnsNameResolver dnsResolver = (DnsNameResolver) resolver.getRetriedNameResolver(); AddressResolver mockResolver = mock(AddressResolver.class); when(mockResolver.resolveAddress(anyString())).thenReturn(answer1).thenReturn(answer2); @@ -570,7 +505,7 @@ public List resolveAddress(String host) throws Exception { ArgumentCaptor ac = ArgumentCaptor.forClass(ResolutionResult.class); verify(mockListener).onResult2(ac.capture()); verifyNoMoreInteractions(mockListener); - assertThat(ac.getValue().getAddresses()).isEmpty(); + assertThat(ac.getValue().getAddressesOrError().getValue()).isEmpty(); assertThat(ac.getValue().getServiceConfig()).isNull(); 
verify(mockResourceResolver, never()).resolveSrv(anyString()); @@ -578,6 +513,39 @@ public List resolveAddress(String host) throws Exception { assertEquals(0, fakeExecutor.numPendingTasks()); } + @Test + public void resolve_addressResolutionError() throws Exception { + DnsNameResolver.enableTxt = true; + when(mockProxyDetector.proxyFor(any(SocketAddress.class))).thenThrow(new IOException()); + RetryingNameResolver resolver = newResolver( + "addr.fake:1234", 443, mockProxyDetector, Stopwatch.createUnstarted()); + DnsNameResolver dnsResolver = (DnsNameResolver) resolver.getRetriedNameResolver(); + dnsResolver.setAddressResolver(new AddressResolver() { + @Override + public List resolveAddress(String host) throws Exception { + return Collections.emptyList(); + } + }); + ResourceResolver mockResourceResolver = mock(ResourceResolver.class); + when(mockResourceResolver.resolveTxt(anyString())) + .thenReturn(Collections.emptyList()); + + dnsResolver.setResourceResolver(mockResourceResolver); + + resolver.start(mockListener); + assertThat(fakeExecutor.runDueTasks()).isEqualTo(1); + + ArgumentCaptor ac = ArgumentCaptor.forClass(ResolutionResult.class); + verify(mockListener).onResult2(ac.capture()); + verifyNoMoreInteractions(mockListener); + assertThat(ac.getValue().getAddressesOrError().getStatus().getCode()).isEqualTo( + Status.UNAVAILABLE.getCode()); + assertThat(ac.getValue().getAddressesOrError().getStatus().getDescription()).isEqualTo( + "Unable to resolve host addr.fake"); + assertThat(ac.getValue().getAddressesOrError().getStatus().getCause()) + .isInstanceOf(IOException.class); + } + // Load balancer rejects the empty addresses. 
@Test public void resolve_emptyResult_notAccepted() throws Exception { @@ -604,7 +572,7 @@ public List resolveAddress(String host) throws Exception { ArgumentCaptor ac = ArgumentCaptor.forClass(ResolutionResult.class); verify(mockListener).onResult2(ac.capture()); verifyNoMoreInteractions(mockListener); - assertThat(ac.getValue().getAddresses()).isEmpty(); + assertThat(ac.getValue().getAddressesOrError().getValue()).isEmpty(); assertThat(ac.getValue().getServiceConfig()).isNull(); verify(mockResourceResolver, never()).resolveSrv(anyString()); @@ -632,7 +600,7 @@ public void resolve_nullResourceResolver() throws Exception { ResolutionResult result = resultCaptor.getValue(); InetSocketAddress resolvedBackendAddr = (InetSocketAddress) Iterables.getOnlyElement( - Iterables.getOnlyElement(result.getAddresses()).getAddresses()); + Iterables.getOnlyElement(result.getAddressesOrError().getValue()).getAddresses()); assertThat(resolvedBackendAddr.getAddress()).isEqualTo(backendAddr); verify(mockAddressResolver).resolveAddress(name); assertThat(result.getServiceConfig()).isNull(); @@ -647,6 +615,7 @@ public void resolve_nullResourceResolver_addressFailure() throws Exception { AddressResolver mockAddressResolver = mock(AddressResolver.class); when(mockAddressResolver.resolveAddress(anyString())) .thenThrow(new IOException("no addr")); + when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.UNAVAILABLE); String name = "foo.googleapis.com"; RetryingNameResolver resolver = newResolver(name, 81); @@ -655,8 +624,8 @@ public void resolve_nullResourceResolver_addressFailure() throws Exception { dnsResolver.setResourceResolver(null); resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onError(errorCaptor.capture()); - Status errorStatus = errorCaptor.getValue(); + verify(mockListener).onResult2(resultCaptor.capture()); + Status errorStatus = resultCaptor.getValue().getAddressesOrError().getStatus(); 
assertThat(errorStatus.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(errorStatus.getCause()).hasMessageThat().contains("no addr"); @@ -704,7 +673,7 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { ResolutionResult result = resultCaptor.getValue(); InetSocketAddress resolvedBackendAddr = (InetSocketAddress) Iterables.getOnlyElement( - Iterables.getOnlyElement(result.getAddresses()).getAddresses()); + Iterables.getOnlyElement(result.getAddressesOrError().getValue()).getAddresses()); assertThat(resolvedBackendAddr.getAddress()).isEqualTo(backendAddr); assertThat(result.getServiceConfig().getConfig()).isNotNull(); verify(mockAddressResolver).resolveAddress(name); @@ -715,11 +684,12 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { } @Test - public void resolve_addressFailure_neverLookUpServiceConfig() throws Exception { + public void resolve_addressFailure_stillLookUpServiceConfig() throws Exception { DnsNameResolver.enableTxt = true; AddressResolver mockAddressResolver = mock(AddressResolver.class); when(mockAddressResolver.resolveAddress(anyString())) .thenThrow(new IOException("no addr")); + when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.UNAVAILABLE); String name = "foo.googleapis.com"; ResourceResolver mockResourceResolver = mock(ResourceResolver.class); @@ -729,11 +699,11 @@ public void resolve_addressFailure_neverLookUpServiceConfig() throws Exception { dnsResolver.setResourceResolver(mockResourceResolver); resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onError(errorCaptor.capture()); - Status errorStatus = errorCaptor.getValue(); + verify(mockListener).onResult2(resultCaptor.capture()); + Status errorStatus = resultCaptor.getValue().getAddressesOrError().getStatus(); assertThat(errorStatus.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(errorStatus.getCause()).hasMessageThat().contains("no addr"); - verify(mockResourceResolver, 
never()).resolveTxt(anyString()); + verify(mockResourceResolver).resolveTxt("_grpc_config." + name); assertEquals(0, fakeClock.numPendingTasks()); // A retry should be scheduled @@ -762,7 +732,7 @@ public void resolve_serviceConfigLookupFails_nullServiceConfig() throws Exceptio ResolutionResult result = resultCaptor.getValue(); InetSocketAddress resolvedBackendAddr = (InetSocketAddress) Iterables.getOnlyElement( - Iterables.getOnlyElement(result.getAddresses()).getAddresses()); + Iterables.getOnlyElement(result.getAddressesOrError().getValue()).getAddresses()); assertThat(resolvedBackendAddr.getAddress()).isEqualTo(backendAddr); verify(mockAddressResolver).resolveAddress(name); assertThat(result.getServiceConfig()).isNull(); @@ -794,7 +764,7 @@ public void resolve_serviceConfigMalformed_serviceConfigError() throws Exception ResolutionResult result = resultCaptor.getValue(); InetSocketAddress resolvedBackendAddr = (InetSocketAddress) Iterables.getOnlyElement( - Iterables.getOnlyElement(result.getAddresses()).getAddresses()); + Iterables.getOnlyElement(result.getAddressesOrError().getValue()).getAddresses()); assertThat(resolvedBackendAddr.getAddress()).isEqualTo(backendAddr); verify(mockAddressResolver).resolveAddress(name); assertThat(result.getServiceConfig()).isNotNull(); @@ -859,7 +829,7 @@ public HttpConnectProxiedSocketAddress proxyFor(SocketAddress targetAddress) { assertEquals(1, fakeExecutor.runDueTasks()); verify(mockListener).onResult2(resultCaptor.capture()); - List result = resultCaptor.getValue().getAddresses(); + List result = resultCaptor.getValue().getAddressesOrError().getValue(); assertThat(result).hasSize(1); EquivalentAddressGroup eag = result.get(0); assertThat(eag.getAddresses()).hasSize(1); @@ -879,9 +849,10 @@ public HttpConnectProxiedSocketAddress proxyFor(SocketAddress targetAddress) { public void maybeChooseServiceConfig_failsOnMisspelling() { Map bad = new LinkedHashMap<>(); bad.put("parcentage", 1.0); - thrown.expectMessage("Bad key"); 
- - DnsNameResolver.maybeChooseServiceConfig(bad, new Random(), "host"); + Random random = new Random(); + VerifyException e = assertThrows(VerifyException.class, + () -> DnsNameResolver.maybeChooseServiceConfig(bad, random, "host")); + assertThat(e).hasMessageThat().isEqualTo("Bad key: parcentage=1.0"); } @Test @@ -1120,25 +1091,25 @@ public void parseTxtResults_misspelledName() throws Exception { } @Test - public void parseTxtResults_badTypeFails() throws Exception { + public void parseTxtResults_badTypeFails() { List txtRecords = new ArrayList<>(); txtRecords.add("some_record"); txtRecords.add("grpc_config={}"); - thrown.expect(ClassCastException.class); - thrown.expectMessage("wrong type"); - DnsNameResolver.parseTxtResults(txtRecords); + ClassCastException e = assertThrows(ClassCastException.class, + () -> DnsNameResolver.parseTxtResults(txtRecords)); + assertThat(e).hasMessageThat().isEqualTo("wrong type {}"); } @Test - public void parseTxtResults_badInnerTypeFails() throws Exception { + public void parseTxtResults_badInnerTypeFails() { List txtRecords = new ArrayList<>(); txtRecords.add("some_record"); txtRecords.add("grpc_config=[\"bogus\"]"); - thrown.expect(ClassCastException.class); - thrown.expectMessage("not object"); - DnsNameResolver.parseTxtResults(txtRecords); + ClassCastException e = assertThrows(ClassCastException.class, + () -> DnsNameResolver.parseTxtResults(txtRecords)); + assertThat(e).hasMessageThat().isEqualTo("value bogus for idx 0 in [bogus] is not object"); } @Test @@ -1271,22 +1242,6 @@ public void parseServiceConfig_matches() { assertThat(result.getConfig()).isEqualTo(ImmutableMap.of()); } - private void testInvalidUri(URI uri) { - try { - provider.newNameResolver(uri, args); - fail("Should have failed"); - } catch (IllegalArgumentException e) { - // expected - } - } - - private void testValidUri(URI uri, String exportedAuthority, int expectedPort) { - DnsNameResolver resolver = (DnsNameResolver) provider.newNameResolver(uri, args); - 
assertNotNull(resolver); - assertEquals(expectedPort, resolver.getPort()); - assertEquals(exportedAuthority, resolver.getServiceAuthority()); - } - private byte lastByte = 0; private List createAddressList(int n) throws UnknownHostException { @@ -1299,9 +1254,9 @@ private List createAddressList(int n) throws UnknownHostException { private static void assertAnswerMatches( List addrs, int port, ResolutionResult resolutionResult) { - assertThat(resolutionResult.getAddresses()).hasSize(addrs.size()); + assertThat(resolutionResult.getAddressesOrError().getValue()).hasSize(addrs.size()); for (int i = 0; i < addrs.size(); i++) { - EquivalentAddressGroup addrGroup = resolutionResult.getAddresses().get(i); + EquivalentAddressGroup addrGroup = resolutionResult.getAddressesOrError().getValue().get(i); InetSocketAddress socketAddr = (InetSocketAddress) Iterables.getOnlyElement(addrGroup.getAddresses()); assertEquals("Addr " + i, port, socketAddr.getPort()); diff --git a/core/src/test/java/io/grpc/internal/ForwardingReadableBufferTest.java b/core/src/test/java/io/grpc/internal/ForwardingReadableBufferTest.java index 8ce45bc77cf..696fb35e379 100644 --- a/core/src/test/java/io/grpc/internal/ForwardingReadableBufferTest.java +++ b/core/src/test/java/io/grpc/internal/ForwardingReadableBufferTest.java @@ -25,7 +25,6 @@ import java.io.IOException; import java.io.OutputStream; import java.lang.reflect.Method; -import java.nio.ByteBuffer; import java.util.Collections; import org.junit.Before; import org.junit.Rule; @@ -91,14 +90,6 @@ public void readBytes() { verify(delegate).readBytes(dest, 1, 2); } - @Test - public void readBytes_overload1() { - ByteBuffer dest = ByteBuffer.allocate(0); - buffer.readBytes(dest); - - verify(delegate).readBytes(dest); - } - @Test public void readBytes_overload2() throws IOException { OutputStream dest = mock(OutputStream.class); diff --git a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java 
index 39acb582d28..c243790028c 100644 --- a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java +++ b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -41,7 +42,6 @@ import java.util.ArrayList; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -57,8 +57,6 @@ public class GrpcUtilTest { new ClientStreamTracer() {} }; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Captor @@ -100,8 +98,8 @@ public void timeoutTest() { GrpcUtil.TimeoutMarshaller marshaller = new GrpcUtil.TimeoutMarshaller(); // nanos - assertEquals("0n", marshaller.toAsciiString(0L)); - assertEquals(0L, (long) marshaller.parseAsciiString("0n")); + assertEquals("1n", marshaller.toAsciiString(1L)); + assertEquals(1L, (long) marshaller.parseAsciiString("1n")); assertEquals("99999999n", marshaller.toAsciiString(99999999L)); assertEquals(99999999L, (long) marshaller.parseAsciiString("99999999n")); @@ -201,9 +199,7 @@ public void urlAuthorityEscape_unicodeAreNotEncoded() { @Test public void checkAuthority_failsOnNull() { - thrown.expect(NullPointerException.class); - - GrpcUtil.checkAuthority(null); + assertThrows(NullPointerException.class, () -> GrpcUtil.checkAuthority(null)); } @Test @@ -229,19 +225,18 @@ public void checkAuthority_succeedsOnIpV6() { @Test public void checkAuthority_failsOnInvalidAuthority() { - thrown.expect(IllegalArgumentException.class); 
- thrown.expectMessage("Invalid authority"); - - GrpcUtil.checkAuthority("[ : : 1]"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> GrpcUtil.checkAuthority("[ : : 1]")); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [ : : 1]"); } @Test public void checkAuthority_userInfoNotAllowed() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Userinfo"); - - GrpcUtil.checkAuthority("foo@valid"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> GrpcUtil.checkAuthority("foo@valid")); + assertThat(e).hasMessageThat() + .isEqualTo("Userinfo must not be present on authority: 'foo@valid'"); } @Test diff --git a/core/src/test/java/io/grpc/internal/InstantTimeProviderTest.java b/core/src/test/java/io/grpc/internal/InstantTimeProviderTest.java new file mode 100644 index 00000000000..6702bc421a5 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/InstantTimeProviderTest.java @@ -0,0 +1,51 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.truth.Truth.assertThat; + +import java.time.Instant; +import java.util.concurrent.TimeUnit; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link InstantTimeProvider}. 
+ */ +@RunWith(JUnit4.class) +@IgnoreJRERequirement +public class InstantTimeProviderTest { + @Test + public void testInstantCurrentTimeNanos() throws Exception { + + InstantTimeProvider instantTimeProvider = new InstantTimeProvider(); + + // Get the current time from the InstantTimeProvider + long actualTimeNanos = instantTimeProvider.currentTimeNanos(); + + // Get the current time from Instant for comparison + Instant instantNow = Instant.now(); + long expectedTimeNanos = TimeUnit.SECONDS.toNanos(instantNow.getEpochSecond()) + + instantNow.getNano(); + + // Validate the time returned is close to the expected value within a tolerance + // (i,e 1000 millisecond (1 second) tolerance in nanoseconds). + assertThat(actualTimeNanos).isWithin(1000_000_000L).of(expectedTimeNanos); + } +} diff --git a/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java b/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java index e4d9f27ed46..811344da307 100644 --- a/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java +++ b/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java @@ -27,11 +27,15 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -46,6 +50,11 @@ import io.grpc.InternalChannelz; import io.grpc.InternalLogId; import io.grpc.InternalWithLogId; +import io.grpc.LoadBalancer; +import 
io.grpc.MetricInstrument; +import io.grpc.MetricRecorder; +import io.grpc.NameResolver; +import io.grpc.SecurityLevel; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.internal.InternalSubchannel.CallTracingTransport; @@ -64,9 +73,9 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -78,11 +87,11 @@ public class InternalSubchannelTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); private static final String AUTHORITY = "fakeauthority"; + private static final String BACKEND_SERVICE = "ice-cream-factory-service"; + private static final String LOCALITY = "mars-olympus-mons-datacenter"; + private static final SecurityLevel SECURITY_LEVEL = SecurityLevel.PRIVACY_AND_INTEGRITY; private static final String USER_AGENT = "mosaic"; private static final ConnectivityStateInfo UNAVAILABLE_STATE = ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE); @@ -110,6 +119,10 @@ public void uncaughtException(Thread t, Throwable e) { @Mock private BackoffPolicy.Provider mockBackoffPolicyProvider; @Mock private ClientTransportFactory mockTransportFactory; + @Mock private BackoffPolicy mockBackoffPolicy; + private MetricRecorder mockMetricRecorder = mock(MetricRecorder.class, + delegatesTo(new MetricRecorderImpl())); + private final LinkedList callbackInvokes = new LinkedList<>(); private final InternalSubchannel.Callback mockInternalSubchannelCallback = new InternalSubchannel.Callback() { @@ -220,7 +233,8 @@ public void constructor_eagListWithNull_throws() { // Fail this one. 
Because there is only one address to try, enter TRANSIENT_FAILURE. assertNoCallbackInvoke(); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(TRANSIENT_FAILURE, internalSubchannel.getState()); assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); // Backoff reset and using first back-off value interval @@ -251,7 +265,8 @@ public void constructor_eagListWithNull_throws() { assertNoCallbackInvoke(); // Here we use a different status from the first failure, and verify that it's passed to // the callback. - transports.poll().listener.transportShutdown(Status.RESOURCE_EXHAUSTED); + transports.poll().listener.transportShutdown(Status.RESOURCE_EXHAUSTED, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(TRANSIENT_FAILURE, internalSubchannel.getState()); assertExactCallbackInvokes("onStateChange:" + RESOURCE_EXHAUSTED_STATE); // Second back-off interval @@ -289,7 +304,8 @@ public void constructor_eagListWithNull_throws() { // Close the READY transport, will enter IDLE state. 
assertNoCallbackInvoke(); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(IDLE, internalSubchannel.getState()); assertExactCallbackInvokes("onStateChange:IDLE"); @@ -309,10 +325,59 @@ public void constructor_eagListWithNull_throws() { verify(mockBackoffPolicy2, times(backoff2Consulted)).nextBackoffNanos(); } + @Test public void twoAddressesReconnectDisabled() { + SocketAddress addr1 = mock(SocketAddress.class); + SocketAddress addr2 = mock(SocketAddress.class); + createInternalSubchannel(true, + new EquivalentAddressGroup(Arrays.asList(addr1, addr2))); + assertEquals(IDLE, internalSubchannel.getState()); + + assertNull(internalSubchannel.obtainActiveTransport()); + assertExactCallbackInvokes("onStateChange:CONNECTING"); + assertEquals(CONNECTING, internalSubchannel.getState()); + verify(mockTransportFactory).newClientTransport(eq(addr1), any(), any()); + // Let this one fail without success + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + // Still in CONNECTING + assertNull(internalSubchannel.obtainActiveTransport()); + assertNoCallbackInvoke(); + assertEquals(CONNECTING, internalSubchannel.getState()); + + // Second attempt will start immediately. Still no back-off policy. + verify(mockBackoffPolicyProvider, times(0)).get(); + verify(mockTransportFactory, times(1)) + .newClientTransport( + eq(addr2), + eq(createClientTransportOptions()), + isA(TransportLogger.class)); + assertNull(internalSubchannel.obtainActiveTransport()); + // Fail this one too + assertNoCallbackInvoke(); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + // All addresses have failed, but we aren't controlling retries. 
+ assertEquals(IDLE, internalSubchannel.getState()); + assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); + // Backoff reset and first back-off interval begins + verify(mockBackoffPolicy1, never()).nextBackoffNanos(); + verify(mockBackoffPolicyProvider, never()).get(); + assertTrue("Nothing should have been scheduled", fakeClock.getPendingTasks().isEmpty()); + + // Should follow orders and create an active transport. + internalSubchannel.obtainActiveTransport(); + assertExactCallbackInvokes("onStateChange:CONNECTING"); + assertEquals(CONNECTING, internalSubchannel.getState()); + + // Shouldn't have anything scheduled, so shouldn't do anything + assertTrue("Nothing should have been scheduled 2", fakeClock.getPendingTasks().isEmpty()); + } + @Test public void twoAddressesReconnect() { SocketAddress addr1 = mock(SocketAddress.class); SocketAddress addr2 = mock(SocketAddress.class); - createInternalSubchannel(addr1, addr2); + createInternalSubchannel(false, + new EquivalentAddressGroup(Arrays.asList(addr1, addr2))); assertEquals(IDLE, internalSubchannel.getState()); // Invocation counters int transportsAddr1 = 0; @@ -334,7 +399,8 @@ public void constructor_eagListWithNull_throws() { isA(TransportLogger.class)); // Let this one fail without success - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // Still in CONNECTING assertNull(internalSubchannel.obtainActiveTransport()); assertNoCallbackInvoke(); @@ -350,7 +416,8 @@ public void constructor_eagListWithNull_throws() { assertNull(internalSubchannel.obtainActiveTransport()); // Fail this one too assertNoCallbackInvoke(); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // All addresses have failed. Delayed transport will be in back-off interval. 
assertEquals(TRANSIENT_FAILURE, internalSubchannel.getState()); assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); @@ -381,7 +448,8 @@ public void constructor_eagListWithNull_throws() { eq(createClientTransportOptions()), isA(TransportLogger.class)); // Fail this one too - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(CONNECTING, internalSubchannel.getState()); // Forth attempt will start immediately. Keep back-off policy. @@ -395,7 +463,8 @@ public void constructor_eagListWithNull_throws() { isA(TransportLogger.class)); // Fail this one too assertNoCallbackInvoke(); - transports.poll().listener.transportShutdown(Status.RESOURCE_EXHAUSTED); + transports.poll().listener.transportShutdown(Status.RESOURCE_EXHAUSTED, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // All addresses have failed again. Delayed transport will be in back-off interval. assertExactCallbackInvokes("onStateChange:" + RESOURCE_EXHAUSTED_STATE); assertEquals(TRANSIENT_FAILURE, internalSubchannel.getState()); @@ -432,7 +501,8 @@ public void constructor_eagListWithNull_throws() { ((CallTracingTransport) internalSubchannel.obtainActiveTransport()).delegate()); // Then close it. 
assertNoCallbackInvoke(); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); assertEquals(IDLE, internalSubchannel.getState()); @@ -448,7 +518,8 @@ public void constructor_eagListWithNull_throws() { eq(createClientTransportOptions()), isA(TransportLogger.class)); // Fail the transport - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(CONNECTING, internalSubchannel.getState()); // Second attempt will start immediately. Still no new back-off policy. @@ -460,7 +531,8 @@ public void constructor_eagListWithNull_throws() { isA(TransportLogger.class)); // Fail this one too assertEquals(CONNECTING, internalSubchannel.getState()); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // All addresses have failed. Enter TRANSIENT_FAILURE. Back-off in effect. 
assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); assertEquals(TRANSIENT_FAILURE, internalSubchannel.getState()); @@ -496,8 +568,9 @@ public void constructor_eagListWithNull_throws() { public void updateAddresses_emptyEagList_throws() { SocketAddress addr = new FakeSocketAddress(); createInternalSubchannel(addr); - thrown.expect(IllegalArgumentException.class); - internalSubchannel.updateAddresses(Arrays.asList()); + List newAddressGroups = Collections.emptyList(); + assertThrows(IllegalArgumentException.class, + () -> internalSubchannel.updateAddresses(newAddressGroups)); } @Test @@ -505,8 +578,7 @@ public void updateAddresses_eagListWithNull_throws() { SocketAddress addr = new FakeSocketAddress(); createInternalSubchannel(addr); List eags = Arrays.asList((EquivalentAddressGroup) null); - thrown.expect(NullPointerException.class); - internalSubchannel.updateAddresses(eags); + assertThrows(NullPointerException.class, () -> internalSubchannel.updateAddresses(eags)); } @Test public void updateAddresses_intersecting_ready() { @@ -524,7 +596,8 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr1), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(CONNECTING, internalSubchannel.getState()); // Second address connects @@ -546,7 +619,8 @@ public void updateAddresses_eagListWithNull_throws() { verify(transports.peek().transport, never()).shutdownNow(any(Status.class)); // And new addresses chosen when re-connecting - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); assertNull(internalSubchannel.obtainActiveTransport()); @@ -556,13 +630,15 @@ public void 
updateAddresses_eagListWithNull_throws() { eq(addr2), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verify(mockTransportFactory) .newClientTransport( eq(addr3), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verifyNoMoreInteractions(mockTransportFactory); fakeClock.forwardNanos(10); // Drain retry, but don't care about result @@ -583,7 +659,8 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr1), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(CONNECTING, internalSubchannel.getState()); // Second address connecting @@ -606,7 +683,8 @@ public void updateAddresses_eagListWithNull_throws() { // And new addresses chosen when re-connecting transports.peek().listener.transportReady(); assertExactCallbackInvokes("onStateChange:READY"); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); assertNull(internalSubchannel.obtainActiveTransport()); @@ -616,13 +694,15 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr2), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); 
verify(mockTransportFactory) .newClientTransport( eq(addr3), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verifyNoMoreInteractions(mockTransportFactory); fakeClock.forwardNanos(10); // Drain retry, but don't care about result @@ -661,7 +741,8 @@ public void updateAddresses_eagListWithNull_throws() { // And no other addresses attempted assertEquals(0, fakeClock.numPendingTasks()); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); assertEquals(TRANSIENT_FAILURE, internalSubchannel.getState()); verifyNoMoreInteractions(mockTransportFactory); @@ -685,7 +766,8 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr1), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(CONNECTING, internalSubchannel.getState()); // Second address connects @@ -709,7 +791,8 @@ public void updateAddresses_eagListWithNull_throws() { verify(transports.peek().transport).shutdown(any(Status.class)); // And new addresses chosen when re-connecting - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertNoCallbackInvoke(); assertEquals(IDLE, internalSubchannel.getState()); @@ -720,13 +803,15 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr3), eq(createClientTransportOptions()), isA(TransportLogger.class)); - 
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verify(mockTransportFactory) .newClientTransport( eq(addr4), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verifyNoMoreInteractions(mockTransportFactory); fakeClock.forwardNanos(10); // Drain retry, but don't care about result @@ -748,7 +833,8 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr1), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(CONNECTING, internalSubchannel.getState()); // Second address connecting @@ -778,13 +864,15 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr3), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verify(mockTransportFactory) .newClientTransport( eq(addr4), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verifyNoMoreInteractions(mockTransportFactory); fakeClock.forwardNanos(10); // Drain retry, but don't care about result @@ -868,7 +956,8 @@ public void connectIsLazy() { isA(TransportLogger.class)); // Fail this one - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + 
transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); // Will always reconnect after back-off @@ -884,7 +973,8 @@ public void connectIsLazy() { transports.peek().listener.transportReady(); assertExactCallbackInvokes("onStateChange:READY"); // Then go-away - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); // No scheduled tasks that would ever try to reconnect ... @@ -914,7 +1004,8 @@ public void shutdownWhenReady() throws Exception { internalSubchannel.shutdown(SHUTDOWN_REASON); verify(transportInfo.transport).shutdown(same(SHUTDOWN_REASON)); assertExactCallbackInvokes("onStateChange:SHUTDOWN"); - transportInfo.listener.transportShutdown(SHUTDOWN_REASON); + transportInfo.listener.transportShutdown(SHUTDOWN_REASON, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo.listener.transportTerminated(); assertExactCallbackInvokes("onTerminated"); @@ -937,7 +1028,8 @@ public void shutdownBeforeTransportCreated() throws Exception { // Fail this one MockClientTransportInfo transportInfo = transports.poll(); - transportInfo.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo.listener.transportTerminated(); // Entering TRANSIENT_FAILURE, waiting for back-off @@ -993,7 +1085,8 @@ public void shutdownBeforeTransportReady() throws Exception { // The transport should've been shut down even though it's not the active transport yet. 
verify(transportInfo.transport).shutdown(same(SHUTDOWN_REASON)); - transportInfo.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertNoCallbackInvoke(); transportInfo.listener.transportTerminated(); assertExactCallbackInvokes("onTerminated"); @@ -1009,7 +1102,7 @@ public void shutdownNow() throws Exception { MockClientTransportInfo t1 = transports.poll(); t1.listener.transportReady(); assertExactCallbackInvokes("onStateChange:CONNECTING", "onStateChange:READY"); - t1.listener.transportShutdown(Status.UNAVAILABLE); + t1.listener.transportShutdown(Status.UNAVAILABLE, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); internalSubchannel.obtainActiveTransport(); @@ -1066,7 +1159,7 @@ public void inUseState() { t0.listener.transportInUse(true); assertExactCallbackInvokes("onInUse"); - t0.listener.transportShutdown(Status.UNAVAILABLE); + t0.listener.transportShutdown(Status.UNAVAILABLE, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); assertNull(internalSubchannel.obtainActiveTransport()); @@ -1099,7 +1192,7 @@ public void transportTerminateWithoutExitingInUse() { t0.listener.transportInUse(true); assertExactCallbackInvokes("onInUse"); - t0.listener.transportShutdown(Status.UNAVAILABLE); + t0.listener.transportShutdown(Status.UNAVAILABLE, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); t0.listener.transportTerminated(); assertExactCallbackInvokes("onNotInUse"); @@ -1126,12 +1219,12 @@ public void run() { assertEquals(1, runnableInvokes.get()); MockClientTransportInfo t0 = transports.poll(); - t0.listener.transportShutdown(Status.UNAVAILABLE); + t0.listener.transportShutdown(Status.UNAVAILABLE, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(2, runnableInvokes.get()); // 2nd address: reconnect immediatly 
MockClientTransportInfo t1 = transports.poll(); - t1.listener.transportShutdown(Status.UNAVAILABLE); + t1.listener.transportShutdown(Status.UNAVAILABLE, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // Addresses exhausted, waiting for back-off. assertEquals(2, runnableInvokes.get()); @@ -1158,7 +1251,8 @@ public void resetConnectBackoff() throws Exception { eq(addr), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); // Save the reconnectTask @@ -1194,7 +1288,8 @@ public void resetConnectBackoff() throws Exception { // Fail the reconnect attempt to verify that a fresh reconnect policy is generated after // invoking resetConnectBackoff() - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); verify(mockBackoffPolicyProvider, times(2)).get(); fakeClock.forwardNanos(10); @@ -1222,7 +1317,8 @@ public void channelzMembership() throws Exception { MockClientTransportInfo t0 = transports.poll(); t0.listener.transportReady(); assertTrue(channelz.containsClientSocket(t0.transport.getLogId())); - t0.listener.transportShutdown(Status.RESOURCE_EXHAUSTED); + t0.listener.transportShutdown(Status.RESOURCE_EXHAUSTED, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); t0.listener.transportTerminated(); assertFalse(channelz.containsClientSocket(t0.transport.getLogId())); } @@ -1377,18 +1473,162 @@ private void createInternalSubchannel(SocketAddress ... addrs) { } private void createInternalSubchannel(EquivalentAddressGroup ... 
addrs) { + createInternalSubchannel(false, addrs); + } + + private void createInternalSubchannel(boolean reconnectDisabled, + EquivalentAddressGroup ... addrs) { List addressGroups = Arrays.asList(addrs); InternalLogId logId = InternalLogId.allocate("Subchannel", /*details=*/ AUTHORITY); ChannelTracer subchannelTracer = new ChannelTracer(logId, 10, fakeClock.getTimeProvider().currentTimeNanos(), "Subchannel"); - internalSubchannel = new InternalSubchannel(addressGroups, AUTHORITY, USER_AGENT, + LoadBalancer.CreateSubchannelArgs.Builder argBuilder = + LoadBalancer.CreateSubchannelArgs.newBuilder().setAddresses(addressGroups); + if (reconnectDisabled) { + argBuilder.addOption(LoadBalancer.DISABLE_SUBCHANNEL_RECONNECT_KEY, reconnectDisabled); + } + LoadBalancer.CreateSubchannelArgs createSubchannelArgs = argBuilder.build(); + internalSubchannel = new InternalSubchannel( + createSubchannelArgs, + AUTHORITY, USER_AGENT, mockBackoffPolicyProvider, mockTransportFactory, fakeClock.getScheduledExecutorService(), fakeClock.getStopwatchSupplier(), syncContext, mockInternalSubchannelCallback, channelz, CallTracer.getDefaultFactory().create(), subchannelTracer, logId, new ChannelLoggerImpl(subchannelTracer, fakeClock.getTimeProvider()), - Collections.emptyList()); + Collections.emptyList(), + "", + new MetricRecorder() { + } + ); + } + + @Test + public void subchannelStateChanges_triggersAttemptFailedMetric() { + // 1. 
Setup: Standard subchannel initialization + when(mockBackoffPolicyProvider.get()).thenReturn(mockBackoffPolicy); + SocketAddress addr = mock(SocketAddress.class); + Attributes eagAttributes = Attributes.newBuilder() + .set(NameResolver.ATTR_BACKEND_SERVICE, BACKEND_SERVICE) + .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, LOCALITY) + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SECURITY_LEVEL) + .build(); + List addressGroups = + Arrays.asList(new EquivalentAddressGroup(Arrays.asList(addr), eagAttributes)); + InternalLogId logId = InternalLogId.allocate("Subchannel", /*details=*/ AUTHORITY); + ChannelTracer subchannelTracer = new ChannelTracer(logId, 10, + fakeClock.getTimeProvider().currentTimeNanos(), "Subchannel"); + LoadBalancer.CreateSubchannelArgs createSubchannelArgs = + LoadBalancer.CreateSubchannelArgs.newBuilder().setAddresses(addressGroups).build(); + internalSubchannel = new InternalSubchannel( + createSubchannelArgs, AUTHORITY, USER_AGENT, mockBackoffPolicyProvider, + mockTransportFactory, fakeClock.getScheduledExecutorService(), + fakeClock.getStopwatchSupplier(), syncContext, mockInternalSubchannelCallback, channelz, + CallTracer.getDefaultFactory().create(), subchannelTracer, logId, + new ChannelLoggerImpl(subchannelTracer, fakeClock.getTimeProvider()), + Collections.emptyList(), AUTHORITY, mockMetricRecorder + ); + + // --- Action: Simulate the "connecting to failed" transition --- + // a. Initiate the connection attempt. The subchannel is now CONNECTING. + internalSubchannel.obtainActiveTransport(); + MockClientTransportInfo transportInfo = transports.poll(); + assertNotNull("A connection attempt should have been made", transportInfo); + + // b. Fail the transport before it can signal `transportReady()`. + transportInfo.listener.transportShutdown( + Status.INTERNAL.withDescription("Simulated connect failure"), + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + fakeClock.runDueTasks(); // Process the failure event + + // --- Verification --- + // a. 
Verify that the "connection_attempts_failed" metric was recorded exactly once. + verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.subchannel.connection_attempts_failed"), + eq(1L), + eq(Arrays.asList(AUTHORITY)), + eq(Arrays.asList(BACKEND_SERVICE, LOCALITY)) + ); + + // b. Verify no other metrics were recorded. This confirms it wasn't incorrectly + // logged as a success, disconnection, or open connection. + verifyNoMoreInteractions(mockMetricRecorder); + } + + @Test + public void subchannelStateChanges_triggersSuccessAndDisconnectMetrics() { + // 1. Mock the backoff policy (needed for subchannel creation) + when(mockBackoffPolicyProvider.get()).thenReturn(mockBackoffPolicy); + + // 2. Setup Subchannel with attributes + SocketAddress addr = mock(SocketAddress.class); + Attributes eagAttributes = Attributes.newBuilder() + .set(NameResolver.ATTR_BACKEND_SERVICE, BACKEND_SERVICE) + .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, LOCALITY) + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SECURITY_LEVEL) + .build(); + List addressGroups = + Arrays.asList(new EquivalentAddressGroup(Arrays.asList(addr), eagAttributes)); + createInternalSubchannel(new EquivalentAddressGroup(addr)); + InternalLogId logId = InternalLogId.allocate("Subchannel", /*details=*/ AUTHORITY); + ChannelTracer subchannelTracer = new ChannelTracer(logId, 10, + fakeClock.getTimeProvider().currentTimeNanos(), "Subchannel"); + LoadBalancer.CreateSubchannelArgs createSubchannelArgs = + LoadBalancer.CreateSubchannelArgs.newBuilder().setAddresses(addressGroups).build(); + internalSubchannel = new InternalSubchannel( + createSubchannelArgs, AUTHORITY, USER_AGENT, mockBackoffPolicyProvider, + mockTransportFactory, fakeClock.getScheduledExecutorService(), + fakeClock.getStopwatchSupplier(), syncContext, mockInternalSubchannelCallback, channelz, + CallTracer.getDefaultFactory().create(), subchannelTracer, logId, + new ChannelLoggerImpl(subchannelTracer, fakeClock.getTimeProvider()), + 
Collections.emptyList(), AUTHORITY, mockMetricRecorder + ); + + // --- Action: Successful connection --- + internalSubchannel.obtainActiveTransport(); + MockClientTransportInfo transportInfo = transports.poll(); + assertNotNull(transportInfo); + transportInfo.listener.transportReady(); + fakeClock.runDueTasks(); // Process the successful connection + + // --- Action: Transport is shut down --- + transportInfo.listener.transportShutdown(Status.UNAVAILABLE.withDescription("unknown"), + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + fakeClock.runDueTasks(); // Process the shutdown + + // --- Verification --- + InOrder inOrder = inOrder(mockMetricRecorder); + + // Verify successful connection metrics + inOrder.verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.subchannel.connection_attempts_succeeded"), + eq(1L), + eq(Arrays.asList(AUTHORITY)), + eq(Arrays.asList(BACKEND_SERVICE, LOCALITY)) + ); + inOrder.verify(mockMetricRecorder).addLongUpDownCounter( + eqMetricInstrumentName("grpc.subchannel.open_connections"), + eq(1L), + eq(Arrays.asList(AUTHORITY)), + eq(Arrays.asList("privacy_and_integrity", BACKEND_SERVICE, LOCALITY)) + ); + + // Verify disconnection metrics + inOrder.verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.subchannel.disconnections"), + eq(1L), + eq(Arrays.asList(AUTHORITY)), + eq(Arrays.asList(BACKEND_SERVICE, LOCALITY, "subchannel shutdown")) + ); + inOrder.verify(mockMetricRecorder).addLongUpDownCounter( + eqMetricInstrumentName("grpc.subchannel.open_connections"), + eq(-1L), + eq(Arrays.asList(AUTHORITY)), + eq(Arrays.asList("privacy_and_integrity", BACKEND_SERVICE, LOCALITY)) + ); + + inOrder.verifyNoMoreInteractions(); } private void assertNoCallbackInvoke() { @@ -1401,5 +1641,13 @@ private void assertExactCallbackInvokes(String ... 
expectedInvokes) { callbackInvokes.clear(); } + static class MetricRecorderImpl implements MetricRecorder { + } + + @SuppressWarnings("TypeParameterUnusedInFormals") + private T eqMetricInstrumentName(String name) { + return argThat(instrument -> instrument.getName().equals(name)); + } + private static class FakeSocketAddress extends SocketAddress {} } diff --git a/core/src/test/java/io/grpc/internal/JsonParserTest.java b/core/src/test/java/io/grpc/internal/JsonParserTest.java index 1e74c753d4d..a0dd81c20ce 100644 --- a/core/src/test/java/io/grpc/internal/JsonParserTest.java +++ b/core/src/test/java/io/grpc/internal/JsonParserTest.java @@ -17,15 +17,14 @@ package io.grpc.internal; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import com.google.gson.stream.MalformedJsonException; import java.io.EOFException; import java.io.IOException; import java.util.ArrayList; import java.util.LinkedHashMap; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -35,10 +34,6 @@ @RunWith(JUnit4.class) public class JsonParserTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void emptyObject() throws IOException { assertEquals(new LinkedHashMap(), JsonParser.parse("{}")); @@ -75,45 +70,33 @@ public void nullValue() throws IOException { } @Test - public void nanFails() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("NaN"); + public void nanFails() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("NaN")); } @Test - public void objectEarlyEnd() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("{foo:}"); + public void objectEarlyEnd() { + assertThrows(MalformedJsonException.class, () -> 
JsonParser.parse("{foo:}")); } @Test - public void earlyEndArray() throws IOException { - thrown.expect(EOFException.class); - - JsonParser.parse("[1, 2, "); + public void earlyEndArray() { + assertThrows(EOFException.class, () -> JsonParser.parse("[1, 2, ")); } @Test - public void arrayMissingElement() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("[1, 2, ]"); + public void arrayMissingElement() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("[1, 2, ]")); } @Test - public void objectMissingElement() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("{1: "); + public void objectMissingElement() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("{1: ")); } @Test - public void objectNoName() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("{: 1"); + public void objectNoName() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("{: 1")); } @Test @@ -123,4 +106,9 @@ public void objectStringName() throws IOException { assertEquals(expected, JsonParser.parse("{\"hi\": 2}")); } + + @Test + public void duplicate() { + assertThrows(IllegalArgumentException.class, () -> JsonParser.parse("{\"hi\": 2, \"hi\": 3}")); + } } diff --git a/core/src/test/java/io/grpc/internal/KeepAliveManagerTest.java b/core/src/test/java/io/grpc/internal/KeepAliveManagerTest.java index 411a9fbe9fc..81e3d1b2638 100644 --- a/core/src/test/java/io/grpc/internal/KeepAliveManagerTest.java +++ b/core/src/test/java/io/grpc/internal/KeepAliveManagerTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -104,13 +105,15 @@ public void 
keepAlivePingDelayedByIncomingData() { @Test public void clientKeepAlivePinger_pingTimeout() { - ConnectionClientTransport transport = mock(ConnectionClientTransport.class); + ClientKeepAlivePinger.TransportWithDisconnectReason transport = + mock(ClientKeepAlivePinger.TransportWithDisconnectReason.class); ClientKeepAlivePinger pinger = new ClientKeepAlivePinger(transport); pinger.onPingTimeout(); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(transport).shutdownNow(statusCaptor.capture()); + verify(transport).shutdownNow(statusCaptor.capture(), + eq(SimpleDisconnectError.CONNECTION_TIMED_OUT)); Status status = statusCaptor.getValue(); assertThat(status.getCode()).isEqualTo(Status.Code.UNAVAILABLE); assertThat(status.getDescription()).isEqualTo( @@ -119,7 +122,8 @@ public void clientKeepAlivePinger_pingTimeout() { @Test public void clientKeepAlivePinger_pingFailure() { - ConnectionClientTransport transport = mock(ConnectionClientTransport.class); + ClientKeepAlivePinger.TransportWithDisconnectReason transport = + mock(ClientKeepAlivePinger.TransportWithDisconnectReason.class); ClientKeepAlivePinger pinger = new ClientKeepAlivePinger(transport); pinger.ping(); ArgumentCaptor pingCallbackCaptor = @@ -127,10 +131,11 @@ public void clientKeepAlivePinger_pingFailure() { verify(transport).ping(pingCallbackCaptor.capture(), isA(Executor.class)); ClientTransport.PingCallback pingCallback = pingCallbackCaptor.getValue(); - pingCallback.onFailure(new Throwable()); + pingCallback.onFailure(Status.UNAVAILABLE.withDescription("I must write descriptions")); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(transport).shutdownNow(statusCaptor.capture()); + verify(transport).shutdownNow(statusCaptor.capture(), + eq(SimpleDisconnectError.CONNECTION_TIMED_OUT)); Status status = statusCaptor.getValue(); assertThat(status.getCode()).isEqualTo(Status.Code.UNAVAILABLE); assertThat(status.getDescription()).isEqualTo( diff --git 
a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java index ddef7ef546f..b0939239477 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.doReturn; @@ -37,8 +38,10 @@ import io.grpc.ClientInterceptor; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; +import io.grpc.FlagResetRule; import io.grpc.InternalConfigurator; import io.grpc.InternalConfiguratorRegistry; +import io.grpc.InternalFeatureFlags; import io.grpc.InternalManagedChannelBuilder.InternalInterceptorFactory; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; @@ -67,15 +70,16 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; /** Unit tests for {@link ManagedChannelImplBuilder}. 
*/ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class ManagedChannelImplBuilderTest { private static final int DUMMY_PORT = 42; private static final String DUMMY_TARGET = "fake-target"; @@ -98,10 +102,16 @@ public ClientCall interceptCall( } }; + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); + @Rule public final FlagResetRule flagResetRule = new FlagResetRule(); @Mock private ClientTransportFactory mockClientTransportFactory; @Mock private ClientTransportFactoryBuilder mockClientTransportFactoryBuilder; @@ -119,6 +129,9 @@ public ClientCall interceptCall( @Before public void setUp() throws Exception { + flagResetRule.setFlagForTest( + InternalFeatureFlags::setRfc3986UrisEnabled, enableRfc3986UrisParam); + builder = new ManagedChannelImplBuilder( DUMMY_TARGET, new UnsupportedClientTransportFactoryBuilder(), @@ -375,8 +388,11 @@ public void transportDoesNotSupportAddressTypes() { ManagedChannel unused = grpcCleanupRule.register(builder.build()); fail("Should fail"); } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().isEqualTo( - "Address types of NameResolver 'dns' for 'valid:1234' not supported by transport"); + assertThat(e) + .hasMessageThat() + .isEqualTo( + "Address types of NameResolver 'dns' for 'dns:///valid:1234' not supported by" + + " transport"); } } @@ -424,10 +440,9 @@ public void checkAuthority_validAuthorityAllowed() { @Test public void checkAuthority_invalidAuthorityFailed() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority"); - - 
builder.checkAuthority(DUMMY_AUTHORITY_INVALID); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.checkAuthority(DUMMY_AUTHORITY_INVALID)); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [ : : 1]"); } @Test @@ -450,11 +465,10 @@ public void enableCheckAuthority_validAuthorityAllowed() { @Test public void disableCheckAuthority_invalidAuthorityFailed() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority"); - builder.disableCheckAuthority().enableCheckAuthority(); - builder.checkAuthority(DUMMY_AUTHORITY_INVALID); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.checkAuthority(DUMMY_AUTHORITY_INVALID)); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [ : : 1]"); } @Test @@ -533,12 +547,9 @@ public void run() { List effectiveInterceptors = builder.getEffectiveInterceptors("unused:///"); assertThat(effectiveInterceptors).hasSize(2); - try { - InternalConfiguratorRegistry.setConfigurators(Collections.emptyList()); - fail("exception expected"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("Configurators are already set"); - } + InternalConfiguratorRegistry.setConfigurators(Collections.emptyList()); + assertThat(InternalConfiguratorRegistry.getConfigurators()).isEmpty(); + assertThat(InternalConfiguratorRegistry.getConfiguratorsCallCountBeforeSet()).isEqualTo(1); } } @@ -680,14 +691,12 @@ public void perRpcBufferLimit() { @Test public void retryBufferSizeInvalidArg() { - thrown.expect(IllegalArgumentException.class); - builder.retryBufferSize(0L); + assertThrows(IllegalArgumentException.class, () -> builder.retryBufferSize(0L)); } @Test public void perRpcBufferLimitInvalidArg() { - thrown.expect(IllegalArgumentException.class); - builder.perRpcBufferLimit(0L); + assertThrows(IllegalArgumentException.class, () -> builder.perRpcBufferLimit(0L)); } @Test @@ -710,8 +719,7 @@ public void 
defaultServiceConfig_nullKey() { Map config = new HashMap<>(); config.put(null, "val"); - thrown.expect(IllegalArgumentException.class); - builder.defaultServiceConfig(config); + assertThrows(IllegalArgumentException.class, () -> builder.defaultServiceConfig(config)); } @Test @@ -721,8 +729,7 @@ public void defaultServiceConfig_intKey() { Map config = new HashMap<>(); config.put("key", subConfig); - thrown.expect(IllegalArgumentException.class); - builder.defaultServiceConfig(config); + assertThrows(IllegalArgumentException.class, () -> builder.defaultServiceConfig(config)); } @Test @@ -730,8 +737,7 @@ public void defaultServiceConfig_intValue() { Map config = new HashMap<>(); config.put("key", 3); - thrown.expect(IllegalArgumentException.class); - builder.defaultServiceConfig(config); + assertThrows(IllegalArgumentException.class, () -> builder.defaultServiceConfig(config)); } @Test @@ -764,6 +770,16 @@ public void disableNameResolverServiceConfig() { assertThat(builder.lookUpServiceConfig).isFalse(); } + @Test + public void setNameResolverExtArgs() { + assertThat(builder.nameResolverCustomArgs) + .isNull(); + + NameResolver.Args.Key testKey = NameResolver.Args.Key.create("test-key"); + builder.setNameResolverArg(testKey, 42); + assertThat(builder.nameResolverCustomArgs.get(testKey)).isEqualTo(42); + } + @Test public void metricSinks() { MetricSink mocksink = mock(MetricSink.class); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverRfc3986Test.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverRfc3986Test.java new file mode 100644 index 00000000000..5bcf24a30e2 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverRfc3986Test.java @@ -0,0 +1,243 @@ +/* + * Copyright 2015 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.internal.UriWrapper.wrap; +import static org.junit.Assert.fail; + +import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; +import io.grpc.NameResolverRegistry; +import io.grpc.Uri; +import java.net.SocketAddress; +import java.net.URI; +import java.util.Collections; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for ManagedChannelImplBuilder#getNameResolverProviderNew(). */ +@RunWith(JUnit4.class) +public class ManagedChannelImplGetNameResolverRfc3986Test { + @Test + public void invalidUriTarget() { + testInvalidTarget("defaultscheme:///[invalid]"); + } + + @Test + public void invalidUnescapedSquareBracketsInRfc3986UriFragment() { + testInvalidTarget("defaultscheme://8.8.8.8/host#section[1]"); + } + + @Test + public void invalidUnescapedSquareBracketsInRfc3986UriQuery() { + testInvalidTarget("dns://8.8.8.8/path?section=[1]"); + } + + @Test + public void validTargetWithInvalidDnsName() throws Exception { + testValidTarget( + "[valid]", + "defaultscheme:///%5Bvalid%5D", + Uri.newBuilder().setScheme("defaultscheme").setHost("").setPath("/[valid]").build()); + } + + @Test + public void validAuthorityTarget() throws Exception { + testValidTarget( + "foo.googleapis.com:8080", + "defaultscheme:///foo.googleapis.com:8080", + Uri.newBuilder() + .setScheme("defaultscheme") + .setHost("") + .setPath("/foo.googleapis.com:8080") + .build()); + } + + @Test + public void 
validUriTarget() throws Exception { + testValidTarget( + "scheme:///foo.googleapis.com:8080", + "scheme:///foo.googleapis.com:8080", + Uri.newBuilder() + .setScheme("scheme") + .setHost("") + .setPath("/foo.googleapis.com:8080") + .build()); + } + + @Test + public void validIpv4AuthorityTarget() throws Exception { + testValidTarget( + "127.0.0.1:1234", + "defaultscheme:///127.0.0.1:1234", + Uri.newBuilder().setScheme("defaultscheme").setHost("").setPath("/127.0.0.1:1234").build()); + } + + @Test + public void validIpv4UriTarget() throws Exception { + testValidTarget( + "dns:///127.0.0.1:1234", + "dns:///127.0.0.1:1234", + Uri.newBuilder().setScheme("dns").setHost("").setPath("/127.0.0.1:1234").build()); + } + + @Test + public void validIpv6AuthorityTarget() throws Exception { + testValidTarget( + "[::1]:1234", + "defaultscheme:///%5B::1%5D:1234", + Uri.newBuilder().setScheme("defaultscheme").setHost("").setPath("/[::1]:1234").build()); + } + + @Test + public void invalidIpv6UriTarget() throws Exception { + testInvalidTarget("dns:///[::1]:1234"); + } + + @Test + public void invalidIpv6UriWithUnescapedScope() { + testInvalidTarget("dns://[::1%eth0]:53/host"); + } + + @Test + public void validIpv6UriTarget() throws Exception { + testValidTarget( + "dns:///%5B::1%5D:1234", + "dns:///%5B::1%5D:1234", + Uri.newBuilder().setScheme("dns").setHost("").setPath("/[::1]:1234").build()); + } + + @Test + public void validTargetStartingWithSlash() throws Exception { + testValidTarget( + "/target", + "defaultscheme:////target", + Uri.newBuilder().setScheme("defaultscheme").setHost("").setPath("//target").build()); + } + + @Test + public void validTargetNoProvider() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + try { + ManagedChannelImplBuilder.getNameResolverProviderRfc3986( + "foo.googleapis.com:8080", nameResolverRegistry); + fail("Should fail"); + } catch (IllegalArgumentException e) { + // expected + } + } + + @Test + public void 
validTargetProviderAddrTypesNotSupported() { + NameResolverRegistry nameResolverRegistry = getTestRegistry("testscheme"); + try { + ManagedChannelImplBuilder.getNameResolverProviderRfc3986( + "testscheme:///foo.googleapis.com:8080", nameResolverRegistry) + .checkAddressTypes(Collections.singleton(CustomSocketAddress.class)); + fail("Should fail"); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .isEqualTo( + "Address types of NameResolver 'testscheme' for " + + "'testscheme:///foo.googleapis.com:8080' not supported by transport"); + } + } + + private void testValidTarget(String target, String expectedUriString, Uri expectedUri) { + NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme()); + ManagedChannelImplBuilder.ResolvedNameResolver resolved = + ManagedChannelImplBuilder.getNameResolverProviderRfc3986(target, nameResolverRegistry); + assertThat(resolved.provider).isInstanceOf(FakeNameResolverProvider.class); + assertThat(resolved.targetUri).isEqualTo(wrap(expectedUri)); + assertThat(resolved.targetUri.toString()).isEqualTo(expectedUriString); + } + + private void testInvalidTarget(String target) { + NameResolverRegistry nameResolverRegistry = getTestRegistry("dns"); + + try { + ManagedChannelImplBuilder.ResolvedNameResolver resolved = + ManagedChannelImplBuilder.getNameResolverProviderRfc3986(target, nameResolverRegistry); + FakeNameResolverProvider nameResolverProvider = (FakeNameResolverProvider) resolved.provider; + fail("Should have failed, but got resolver provider " + nameResolverProvider); + } catch (IllegalArgumentException e) { + // expected + } + } + + private static NameResolverRegistry getTestRegistry(String expectedScheme) { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + FakeNameResolverProvider nameResolverProvider = new FakeNameResolverProvider(expectedScheme); + nameResolverRegistry.register(nameResolverProvider); + return nameResolverRegistry; + } + + 
private static class FakeNameResolverProvider extends NameResolverProvider { + final String expectedScheme; + + FakeNameResolverProvider(String expectedScheme) { + this.expectedScheme = expectedScheme; + } + + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + if (expectedScheme.equals(targetUri.getScheme())) { + return new FakeNameResolver(targetUri); + } + return null; + } + + @Override + public String getDefaultScheme() { + return expectedScheme; + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + } + + private static class FakeNameResolver extends NameResolver { + final URI uri; + + FakeNameResolver(URI uri) { + this.uri = uri; + } + + @Override + public String getServiceAuthority() { + return uri.getAuthority(); + } + + @Override + public void start(final Listener2 listener) {} + + @Override + public void shutdown() {} + } + + private static class CustomSocketAddress extends SocketAddress {} +} diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java index a0bd388b1b6..792f4daca4e 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java @@ -17,12 +17,12 @@ package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.internal.UriWrapper.wrap; import static org.junit.Assert.fail; import io.grpc.NameResolver; import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; -import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; import java.util.Collections; @@ -38,6 +38,21 @@ public void invalidUriTarget() { testInvalidTarget("defaultscheme:///[invalid]"); } + @Test + public void validSquareBracketsInRfc2396UriFragment() throws 
Exception { + testValidTarget("dns://8.8.8.8/host#section[1]", + "dns://8.8.8.8/host#section[1]", + new URI("dns", "8.8.8.8", "/host", null, "section[1]")); + } + + + @Test + public void validSquareBracketsInRfc2396UriQuery() throws Exception { + testValidTarget("dns://8.8.8.8/host?section=[1]", + "dns://8.8.8.8/host?section=[1]", + new URI("dns", "8.8.8.8", "/host", "section=[1]", null)); + } + @Test public void validTargetWithInvalidDnsName() throws Exception { testValidTarget("[valid]", "defaultscheme:///%5Bvalid%5D", @@ -74,6 +89,13 @@ public void validIpv6AuthorityTarget() throws Exception { new URI("defaultscheme", "", "/[::1]:1234", null)); } + @Test + public void validIpv6UriWithJavaNetUriScopeName() throws Exception { + testValidTarget("dns://[::1%eth0]:53/host", + "dns://[::1%eth0]:53/host", + new URI("dns", "[::1%eth0]:53", "/host", null, null)); + } + @Test public void invalidIpv6UriTarget() throws Exception { testInvalidTarget("dns:///[::1]:1234"); @@ -96,8 +118,7 @@ public void validTargetNoProvider() { NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); try { ManagedChannelImplBuilder.getNameResolverProvider( - "foo.googleapis.com:8080", nameResolverRegistry, - Collections.singleton(InetSocketAddress.class)); + "foo.googleapis.com:8080", nameResolverRegistry); fail("Should fail"); } catch (IllegalArgumentException e) { // expected @@ -109,8 +130,8 @@ public void validTargetProviderAddrTypesNotSupported() { NameResolverRegistry nameResolverRegistry = getTestRegistry("testscheme"); try { ManagedChannelImplBuilder.getNameResolverProvider( - "testscheme:///foo.googleapis.com:8080", nameResolverRegistry, - Collections.singleton(CustomSocketAddress.class)); + "testscheme:///foo.googleapis.com:8080", nameResolverRegistry) + .checkAddressTypes(Collections.singleton(CustomSocketAddress.class)); fail("Should fail"); } catch (IllegalArgumentException e) { assertThat(e).hasMessageThat().isEqualTo( @@ -122,10 +143,9 @@ public void 
validTargetProviderAddrTypesNotSupported() { private void testValidTarget(String target, String expectedUriString, URI expectedUri) { NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme()); ManagedChannelImplBuilder.ResolvedNameResolver resolved = - ManagedChannelImplBuilder.getNameResolverProvider( - target, nameResolverRegistry, Collections.singleton(InetSocketAddress.class)); + ManagedChannelImplBuilder.getNameResolverProvider(target, nameResolverRegistry); assertThat(resolved.provider).isInstanceOf(FakeNameResolverProvider.class); - assertThat(resolved.targetUri).isEqualTo(expectedUri); + assertThat(resolved.targetUri).isEqualTo(wrap(expectedUri)); assertThat(resolved.targetUri.toString()).isEqualTo(expectedUriString); } @@ -134,8 +154,7 @@ private void testInvalidTarget(String target) { try { ManagedChannelImplBuilder.ResolvedNameResolver resolved = - ManagedChannelImplBuilder.getNameResolverProvider( - target, nameResolverRegistry, Collections.singleton(InetSocketAddress.class)); + ManagedChannelImplBuilder.getNameResolverProvider(target, nameResolverRegistry); FakeNameResolverProvider nameResolverProvider = (FakeNameResolverProvider) resolved.provider; fail("Should have failed, but got resolver provider " + nameResolverProvider); } catch (IllegalArgumentException e) { diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java index 90008c1be30..97e92be7fdd 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.internal.UriWrapper.wrap; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; 
import static org.junit.Assert.assertNull; @@ -63,6 +64,7 @@ import io.grpc.NameResolver.ResolutionResult; import io.grpc.NameResolverProvider; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.StringMarshaller; import io.grpc.internal.FakeClock.ScheduledTask; import io.grpc.internal.ManagedChannelImplBuilder.UnsupportedClientTransportFactoryBuilder; @@ -184,7 +186,7 @@ public void setUp() { NameResolverProvider nameResolverProvider = builder.nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); channel = new ManagedChannelImpl( - builder, mockTransportFactory, targetUri, nameResolverProvider, + builder, mockTransportFactory, wrap(targetUri), nameResolverProvider, new FakeBackoffPolicyProvider(), oobExecutorPool, timer.getStopwatchSupplier(), Collections.emptyList(), @@ -615,7 +617,7 @@ private void deliverResolutionResult() { // the NameResolver. ResolutionResult resolutionResult = ResolutionResult.newBuilder() - .setAddresses(servers) + .setAddressesOrError(StatusOr.fromValue(servers)) .setAttributes(Attributes.EMPTY) .build(); nameResolverListenerCaptor.getValue().onResult(resolutionResult); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 4d42056b689..ae224af27e1 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -28,6 +28,7 @@ import static io.grpc.EquivalentAddressGroup.ATTR_AUTHORITY_OVERRIDE; import static io.grpc.PickSubchannelArgsMatcher.eqPickSubchannelArgs; import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; +import static io.grpc.internal.UriWrapper.wrap; import static junit.framework.TestCase.assertNotSame; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -57,6 +58,7 @@ import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import 
com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; @@ -84,6 +86,7 @@ import io.grpc.InternalChannelz; import io.grpc.InternalChannelz.ChannelStats; import io.grpc.InternalChannelz.ChannelTrace; +import io.grpc.InternalChannelz.ChannelTrace.Event.Severity; import io.grpc.InternalConfigSelector; import io.grpc.InternalInstrumented; import io.grpc.LoadBalancer; @@ -115,6 +118,7 @@ import io.grpc.ServerMethodDefinition; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.StatusOr; import io.grpc.StringMarshaller; import io.grpc.SynchronizationContext; import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; @@ -123,6 +127,7 @@ import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider; import io.grpc.internal.ManagedChannelImplBuilder.UnsupportedClientTransportFactoryBuilder; +import io.grpc.internal.ManagedChannelServiceConfig.MethodInfo; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.internal.TestUtils.MockClientTransportInfo; import io.grpc.stub.ClientCalls; @@ -201,6 +206,8 @@ public class ManagedChannelImplTest { .setServiceConfigParser(mock(NameResolver.ServiceConfigParser.class)) .setScheduledExecutorService(new FakeClock().getScheduledExecutorService()) .build(); + private static final NameResolver.Args.Key TEST_RESOLVER_CUSTOM_ARG_KEY = + NameResolver.Args.Key.create("test-key"); private URI expectedUri; private final SocketAddress socketAddress = @@ -279,10 +286,6 @@ public String getPolicyName() { @Mock private ClientCall.Listener mockCallListener3; @Mock - private ClientCall.Listener mockCallListener4; - @Mock - private ClientCall.Listener mockCallListener5; - @Mock private ObjectPool executorPool; @Mock private 
ObjectPool balancerRpcExecutorPool; @@ -313,7 +316,7 @@ private void createChannel(boolean nameResolutionExpectedToFail, NameResolverProvider nameResolverProvider = channelBuilder.nameResolverRegistry.getProviderForScheme(expectedUri.getScheme()); channel = new ManagedChannelImpl( - channelBuilder, mockTransportFactory, expectedUri, nameResolverProvider, + channelBuilder, mockTransportFactory, wrap(expectedUri), nameResolverProvider, new FakeBackoffPolicyProvider(), balancerRpcExecutorPool, timer.getStopwatchSupplier(), Arrays.asList(interceptors), timer.getTimeProvider()); @@ -502,7 +505,7 @@ public void startCallBeforeNameResolution() throws Exception { when(mockTransportFactory.getSupportedSocketAddressTypes()).thenReturn(Collections.singleton( InetSocketAddress.class)); channel = new ManagedChannelImpl( - channelBuilder, mockTransportFactory, expectedUri, nameResolverFactory, + channelBuilder, mockTransportFactory, wrap(expectedUri), nameResolverFactory, new FakeBackoffPolicyProvider(), balancerRpcExecutorPool, timer.getStopwatchSupplier(), Collections.emptyList(), timer.getTimeProvider()); @@ -567,7 +570,7 @@ public void newCallWithConfigSelector() { when(mockTransportFactory.getSupportedSocketAddressTypes()).thenReturn(Collections.singleton( InetSocketAddress.class)); channel = new ManagedChannelImpl( - channelBuilder, mockTransportFactory, expectedUri, nameResolverFactory, + channelBuilder, mockTransportFactory, wrap(expectedUri), nameResolverFactory, new FakeBackoffPolicyProvider(), balancerRpcExecutorPool, timer.getStopwatchSupplier(), Collections.emptyList(), timer.getTimeProvider()); @@ -684,6 +687,30 @@ public void metricRecorder_recordsToMetricSink() { eq(optionalLabelValues)); } + @Test + public void metricRecorder_fromNameResolverArgs_recordsToMetricSink() { + MetricSink mockSink1 = mock(MetricSink.class); + MetricSink mockSink2 = mock(MetricSink.class); + channelBuilder.addMetricSink(mockSink1); + channelBuilder.addMetricSink(mockSink2); + 
createChannel(); + + LongCounterMetricInstrument counter = metricInstrumentRegistry.registerLongCounter( + "test_counter", "Time taken by metric recorder", "s", + ImmutableList.of("grpc.method"), Collections.emptyList(), false); + List requiredLabelValues = ImmutableList.of("testMethod"); + List optionalLabelValues = Collections.emptyList(); + + NameResolver.Args args = helper.getNameResolverArgs(); + assertThat(args.getMetricRecorder()).isNotNull(); + args.getMetricRecorder() + .addLongCounter(counter, 10, requiredLabelValues, optionalLabelValues); + verify(mockSink1).addLongCounter(eq(counter), eq(10L), eq(requiredLabelValues), + eq(optionalLabelValues)); + verify(mockSink2).addLongCounter(eq(counter), eq(10L), eq(requiredLabelValues), + eq(optionalLabelValues)); + } + @Test public void shutdownWithNoTransportsEverCreated() { channelBuilder.nameResolverFactory( @@ -768,7 +795,8 @@ public void channelzMembership_subchannel() throws Exception { transportInfo.listener.transportReady(); // terminate transport - transportInfo.listener.transportShutdown(Status.CANCELLED); + transportInfo.listener.transportShutdown(Status.CANCELLED, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo.listener.transportTerminated(); assertFalse(channelz.containsClientSocket(transportInfo.transport.getLogId())); @@ -786,46 +814,6 @@ public void channelzMembership_subchannel() throws Exception { assertNotNull(channelz.getRootChannel(channel.getLogId().getId())); } - @Test - public void channelzMembership_oob() throws Exception { - createChannel(); - OobChannel oob = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), AUTHORITY); - // oob channels are not root channels - assertNull(channelz.getRootChannel(oob.getLogId().getId())); - assertTrue(channelz.containsSubchannel(oob.getLogId())); - assertThat(getStats(channel).subchannels).containsExactly(oob); - assertTrue(channelz.containsSubchannel(oob.getLogId())); - - AbstractSubchannel subchannel = 
(AbstractSubchannel) oob.getSubchannel(); - assertTrue( - channelz.containsSubchannel(subchannel.getInstrumentedInternalSubchannel().getLogId())); - assertThat(getStats(oob).subchannels) - .containsExactly(subchannel.getInstrumentedInternalSubchannel()); - assertTrue( - channelz.containsSubchannel(subchannel.getInstrumentedInternalSubchannel().getLogId())); - - oob.getSubchannel().requestConnection(); - MockClientTransportInfo transportInfo = transports.poll(); - assertNotNull(transportInfo); - assertTrue(channelz.containsClientSocket(transportInfo.transport.getLogId())); - - // terminate transport - transportInfo.listener.transportShutdown(Status.INTERNAL); - transportInfo.listener.transportTerminated(); - assertFalse(channelz.containsClientSocket(transportInfo.transport.getLogId())); - - // terminate oobchannel - oob.shutdown(); - assertFalse(channelz.containsSubchannel(oob.getLogId())); - assertThat(getStats(channel).subchannels).isEmpty(); - assertFalse( - channelz.containsSubchannel(subchannel.getInstrumentedInternalSubchannel().getLogId())); - - // channel still appears - assertNotNull(channelz.getRootChannel(channel.getLogId().getId())); - } - @Test public void callsAndShutdown() { subtestCallsAndShutdown(false, false); @@ -972,7 +960,8 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft } // Killing the remaining real transport will terminate the channel - transportListener.transportShutdown(Status.UNAVAILABLE); + transportListener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertFalse(channel.isTerminated()); verify(executorPool, never()).returnObject(any()); transportListener.transportTerminated(); @@ -1042,7 +1031,8 @@ public void noMoreCallbackAfterLoadBalancerShutdown() { // Since subchannels are shutdown, SubchannelStateListeners will only get SHUTDOWN regardless of // the transport states. 
- transportInfo1.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo1.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo2.listener.transportReady(); verify(stateListener1).onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); verify(stateListener2).onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); @@ -1054,7 +1044,7 @@ public void noMoreCallbackAfterLoadBalancerShutdown() { verifyNoMoreInteractions(mockLoadBalancer); } - @Test + @Test public void noMoreCallbackAfterLoadBalancerShutdown_configError() throws InterruptedException { FakeNameResolverFactory nameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri) @@ -1093,7 +1083,10 @@ public void noMoreCallbackAfterLoadBalancerShutdown_configError() throws Interru verify(stateListener2).onSubchannelState(stateInfoCaptor.capture()); assertSame(CONNECTING, stateInfoCaptor.getValue().getState()); - resolver.listener.onError(resolutionError); + channel.syncContext.execute(() -> + resolver.listener.onResult2( + ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromStatus(resolutionError)).build())); verify(mockLoadBalancer).handleNameResolutionError(resolutionError); verifyNoMoreInteractions(mockLoadBalancer); @@ -1108,25 +1101,74 @@ public void noMoreCallbackAfterLoadBalancerShutdown_configError() throws Interru // Since subchannels are shutdown, SubchannelStateListeners will only get SHUTDOWN regardless of // the transport states. 
- transportInfo1.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo1.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo2.listener.transportReady(); verify(stateListener1).onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); verify(stateListener2).onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); verifyNoMoreInteractions(stateListener1, stateListener2); // No more callback should be delivered to LoadBalancer after it's shut down - resolver.listener.onResult( - ResolutionResult.newBuilder() - .setAddresses(new ArrayList<>()) - .setServiceConfig( - ConfigOrError.fromError(Status.UNAVAILABLE.withDescription("Resolution failed"))) - .build()); - Thread.sleep(1100); + channel.syncContext.execute(() -> + resolver.listener.onResult2( + ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromStatus(resolutionError)).build())); assertThat(timer.getPendingTasks()).isEmpty(); resolver.resolved(); verifyNoMoreInteractions(mockLoadBalancer); } + @Test + public void addressResolutionError_noPriorNameResolution_usesDefaultServiceConfig() + throws Exception { + Map rawServiceConfig = + parseConfig("{\"methodConfig\":[{" + + "\"name\":[{\"service\":\"service\"}]," + + "\"waitForReady\":true}]}"); + ManagedChannelServiceConfig managedChannelServiceConfig = + createManagedChannelServiceConfig(rawServiceConfig, null); + FakeNameResolverFactory nameResolverFactory = + new FakeNameResolverFactory.Builder(expectedUri) + .setServers(Collections.singletonList(new EquivalentAddressGroup(socketAddress))) + .setResolvedAtStart(false) + .build(); + nameResolverFactory.nextConfigOrError.set( + ConfigOrError.fromConfig(managedChannelServiceConfig)); + channelBuilder.nameResolverFactory(nameResolverFactory); + Map defaultServiceConfig = + parseConfig("{\"methodConfig\":[{" + + "\"name\":[{\"service\":\"service\"}]," + + "\"waitForReady\":true}]}"); + 
channelBuilder.defaultServiceConfig(defaultServiceConfig); + Status resolutionError = Status.UNAVAILABLE.withDescription("Resolution failed"); + channelBuilder.maxTraceEvents(10); + createChannel(); + FakeNameResolverFactory.FakeNameResolver resolver = nameResolverFactory.resolvers.get(0); + + resolver.listener.onError(resolutionError); + + InternalConfigSelector configSelector = channel.getConfigSelector(); + ManagedChannelServiceConfig config = + (ManagedChannelServiceConfig) configSelector.selectConfig(null).getConfig(); + MethodInfo methodConfig = config.getMethodConfig(method); + assertThat(methodConfig.waitForReady).isTrue(); + timer.forwardNanos(1234); + assertThat(getStats(channel).channelTrace.events).contains(new ChannelTrace.Event.Builder() + .setDescription("Initial Name Resolution error, using default service config") + .setSeverity(Severity.CT_ERROR) + .setTimestampNanos(0) + .build()); + + // Check that "lastServiceConfig" variable has been set above: a config resolution with the same + // config simply gets ignored and not gets reassigned. + resolver.resolved(); + timer.forwardNanos(1234); + assertThat(Iterables.filter( + getStats(channel).channelTrace.events, + event -> event.description.equals("Service config changed"))) + .isEmpty(); + } + @Test public void interceptor() throws Exception { final AtomicLong atomic = new AtomicLong(); @@ -1196,7 +1238,8 @@ public void callOptionsExecutor() { verify(mockCallListener).onClose(same(Status.CANCELLED), same(trailers)); - transportListener.transportShutdown(Status.UNAVAILABLE); + transportListener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportListener.transportTerminated(); // Clean up as much as possible to allow the channel to terminate. 
@@ -1262,7 +1305,7 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata header PickResult.withSubchannel(subchannel)); updateBalancingStateSafely(helper, READY, mockPicker); - assertEquals(2, executor.runDueTasks()); + assertEquals(3, executor.runDueTasks()); verify(mockPicker).pickSubchannel(any(PickSubchannelArgs.class)); verify(mockTransport).newStream( @@ -1348,7 +1391,8 @@ public void firstResolvedServerFailedToConnect() throws Exception { MockClientTransportInfo badTransportInfo = transports.poll(); // Which failed to connect - badTransportInfo.listener.transportShutdown(Status.UNAVAILABLE); + badTransportInfo.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); inOrder.verifyNoMoreInteractions(); // The channel then try the second address (goodAddress) @@ -1498,7 +1542,8 @@ public void allServersFailedToConnect() throws Exception { .newClientTransport( same(addr2), any(ClientTransportOptions.class), any(ChannelLogger.class)); MockClientTransportInfo transportInfo1 = transports.poll(); - transportInfo1.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo1.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // Connecting to server2, which will fail too verify(mockTransportFactory) @@ -1506,7 +1551,8 @@ public void allServersFailedToConnect() throws Exception { same(addr2), any(ClientTransportOptions.class), any(ChannelLogger.class)); MockClientTransportInfo transportInfo2 = transports.poll(); Status server2Error = Status.UNAVAILABLE.withDescription("Server2 failed to connect"); - transportInfo2.listener.transportShutdown(server2Error); + transportInfo2.listener.transportShutdown(server2Error, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // ... which makes the subchannel enter TRANSIENT_FAILURE. The last error Status is propagated // to LoadBalancer. 
@@ -1616,9 +1662,11 @@ public void run() { verify(transportInfo2.transport).shutdown(same(ManagedChannelImpl.SHUTDOWN_STATUS)); // Cleanup - transportInfo1.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo1.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo1.listener.transportTerminated(); - transportInfo2.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo2.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo2.listener.transportTerminated(); timer.forwardTime(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS, TimeUnit.SECONDS); } @@ -1658,8 +1706,10 @@ public void subchannelsWhenChannelShutdownNow() { verify(ti1.transport).shutdownNow(any(Status.class)); verify(ti2.transport).shutdownNow(any(Status.class)); - ti1.listener.transportShutdown(Status.UNAVAILABLE.withDescription("shutdown now")); - ti2.listener.transportShutdown(Status.UNAVAILABLE.withDescription("shutdown now")); + ti1.listener.transportShutdown(Status.UNAVAILABLE.withDescription("shutdown now"), + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + ti2.listener.transportShutdown(Status.UNAVAILABLE.withDescription("shutdown now"), + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); ti1.listener.transportTerminated(); assertFalse(channel.isTerminated()); @@ -1686,6 +1736,19 @@ public void subchannelsNoConnectionShutdown() { any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); } + @Test + public void subchannelsRequestConnectionNoopAfterShutdown() { + createChannel(); + Subchannel sub1 = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); + + shutdownSafely(helper, sub1); + requestConnectionSafely(helper, sub1); + verify(mockTransportFactory, never()) + .newClientTransport( + any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); + } + @Test public void 
subchannelsNoConnectionShutdownNow() { createChannel(); @@ -1694,7 +1757,7 @@ public void subchannelsNoConnectionShutdownNow() { channel.shutdownNow(); verify(mockLoadBalancer).shutdown(); - // Channel's shutdownNow() will call shutdownNow() on all subchannels and oobchannels. + // Channel's shutdownNow() will call shutdownNow() on all subchannels. // Therefore, channel is terminated without relying on LoadBalancer to shutdown subchannels. assertTrue(channel.isTerminated()); verify(mockTransportFactory, never()) @@ -1702,112 +1765,6 @@ public void subchannelsNoConnectionShutdownNow() { any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); } - @Test - public void oobchannels() { - createChannel(); - - ManagedChannel oob1 = helper.createOobChannel( - Collections.singletonList(addressGroup), "oob1authority"); - ManagedChannel oob2 = helper.createOobChannel( - Collections.singletonList(addressGroup), "oob2authority"); - verify(balancerRpcExecutorPool, times(2)).getObject(); - - assertEquals("oob1authority", oob1.authority()); - assertEquals("oob2authority", oob2.authority()); - - // OOB channels create connections lazily. A new call will initiate the connection. 
- Metadata headers = new Metadata(); - ClientCall call = oob1.newCall(method, CallOptions.DEFAULT); - call.start(mockCallListener, headers); - verify(mockTransportFactory) - .newClientTransport( - eq(socketAddress), - eq(new ClientTransportOptions().setAuthority("oob1authority").setUserAgent(USER_AGENT)), - isA(ChannelLogger.class)); - MockClientTransportInfo transportInfo = transports.poll(); - assertNotNull(transportInfo); - - assertEquals(0, balancerRpcExecutor.numPendingTasks()); - transportInfo.listener.transportReady(); - assertEquals(1, balancerRpcExecutor.runDueTasks()); - verify(transportInfo.transport).newStream( - same(method), same(headers), same(CallOptions.DEFAULT), - ArgumentMatchers.any()); - - // The transport goes away - transportInfo.listener.transportShutdown(Status.UNAVAILABLE); - transportInfo.listener.transportTerminated(); - - // A new call will trigger a new transport - ClientCall call2 = oob1.newCall(method, CallOptions.DEFAULT); - call2.start(mockCallListener2, headers); - ClientCall call3 = - oob1.newCall(method, CallOptions.DEFAULT.withWaitForReady()); - call3.start(mockCallListener3, headers); - verify(mockTransportFactory, times(2)).newClientTransport( - eq(socketAddress), - eq(new ClientTransportOptions().setAuthority("oob1authority").setUserAgent(USER_AGENT)), - isA(ChannelLogger.class)); - transportInfo = transports.poll(); - assertNotNull(transportInfo); - - // This transport fails - Status transportError = Status.UNAVAILABLE.withDescription("Connection refused"); - assertEquals(0, balancerRpcExecutor.numPendingTasks()); - transportInfo.listener.transportShutdown(transportError); - assertTrue(balancerRpcExecutor.runDueTasks() > 0); - - // Fail-fast RPC will fail, while wait-for-ready RPC will still be pending - verify(mockCallListener2).onClose(same(transportError), any(Metadata.class)); - verify(mockCallListener3, never()).onClose(any(Status.class), any(Metadata.class)); - - // Shutdown - assertFalse(oob1.isShutdown()); - 
assertFalse(oob2.isShutdown()); - oob1.shutdown(); - oob2.shutdownNow(); - assertTrue(oob1.isShutdown()); - assertTrue(oob2.isShutdown()); - assertTrue(oob2.isTerminated()); - verify(balancerRpcExecutorPool).returnObject(balancerRpcExecutor.getScheduledExecutorService()); - - // New RPCs will be rejected. - assertEquals(0, balancerRpcExecutor.numPendingTasks()); - ClientCall call4 = oob1.newCall(method, CallOptions.DEFAULT); - ClientCall call5 = oob2.newCall(method, CallOptions.DEFAULT); - call4.start(mockCallListener4, headers); - call5.start(mockCallListener5, headers); - assertTrue(balancerRpcExecutor.runDueTasks() > 0); - verify(mockCallListener4).onClose(statusCaptor.capture(), any(Metadata.class)); - Status status4 = statusCaptor.getValue(); - assertEquals(Status.Code.UNAVAILABLE, status4.getCode()); - verify(mockCallListener5).onClose(statusCaptor.capture(), any(Metadata.class)); - Status status5 = statusCaptor.getValue(); - assertEquals(Status.Code.UNAVAILABLE, status5.getCode()); - - // The pending RPC will still be pending - verify(mockCallListener3, never()).onClose(any(Status.class), any(Metadata.class)); - - // This will shutdownNow() the delayed transport, terminating the pending RPC - assertEquals(0, balancerRpcExecutor.numPendingTasks()); - oob1.shutdownNow(); - assertTrue(balancerRpcExecutor.runDueTasks() > 0); - verify(mockCallListener3).onClose(any(Status.class), any(Metadata.class)); - - // Shut down the channel, and it will not terminated because OOB channel has not. - channel.shutdown(); - assertFalse(channel.isTerminated()); - // Delayed transport has already terminated. Terminating the transport terminates the - // subchannel, which in turn terimates the OOB channel, which terminates the channel. 
- assertFalse(oob1.isTerminated()); - verify(balancerRpcExecutorPool).returnObject(balancerRpcExecutor.getScheduledExecutorService()); - transportInfo.listener.transportTerminated(); - assertTrue(oob1.isTerminated()); - assertTrue(channel.isTerminated()); - verify(balancerRpcExecutorPool, times(2)) - .returnObject(balancerRpcExecutor.getScheduledExecutorService()); - } - @Test public void oobChannelHasNoChannelCallCredentials() { Metadata.Key metadataKey = @@ -1859,7 +1816,7 @@ public void oobChannelHasNoChannelCallCredentials() { balancerRpcExecutor.runDueTasks(); verify(transportInfo.transport).newStream( - same(method), same(headers), same(callOptions), + same(method), same(headers), ArgumentMatchers.any(), ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue); oob.shutdownNow(); @@ -1986,74 +1943,6 @@ public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { oob.shutdownNow(); } - @Test - public void oobChannelsWhenChannelShutdownNow() { - createChannel(); - ManagedChannel oob1 = helper.createOobChannel( - Collections.singletonList(addressGroup), "oob1Authority"); - ManagedChannel oob2 = helper.createOobChannel( - Collections.singletonList(addressGroup), "oob2Authority"); - - oob1.newCall(method, CallOptions.DEFAULT).start(mockCallListener, new Metadata()); - oob2.newCall(method, CallOptions.DEFAULT).start(mockCallListener2, new Metadata()); - - assertThat(transports).hasSize(2); - MockClientTransportInfo ti1 = transports.poll(); - MockClientTransportInfo ti2 = transports.poll(); - - ti1.listener.transportReady(); - ti2.listener.transportReady(); - - channel.shutdownNow(); - verify(ti1.transport).shutdownNow(any(Status.class)); - verify(ti2.transport).shutdownNow(any(Status.class)); - - ti1.listener.transportShutdown(Status.UNAVAILABLE.withDescription("shutdown now")); - ti2.listener.transportShutdown(Status.UNAVAILABLE.withDescription("shutdown now")); - ti1.listener.transportTerminated(); - - 
assertFalse(channel.isTerminated()); - ti2.listener.transportTerminated(); - assertTrue(channel.isTerminated()); - } - - @Test - public void oobChannelsNoConnectionShutdown() { - createChannel(); - ManagedChannel oob1 = helper.createOobChannel( - Collections.singletonList(addressGroup), "oob1Authority"); - ManagedChannel oob2 = helper.createOobChannel( - Collections.singletonList(addressGroup), "oob2Authority"); - channel.shutdown(); - - verify(mockLoadBalancer).shutdown(); - oob1.shutdown(); - assertTrue(oob1.isTerminated()); - assertFalse(channel.isTerminated()); - oob2.shutdown(); - assertTrue(oob2.isTerminated()); - assertTrue(channel.isTerminated()); - verify(mockTransportFactory, never()) - .newClientTransport( - any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); - } - - @Test - public void oobChannelsNoConnectionShutdownNow() { - createChannel(); - helper.createOobChannel(Collections.singletonList(addressGroup), "oob1Authority"); - helper.createOobChannel(Collections.singletonList(addressGroup), "oob2Authority"); - channel.shutdownNow(); - - verify(mockLoadBalancer).shutdown(); - assertTrue(channel.isTerminated()); - // Channel's shutdownNow() will call shutdownNow() on all subchannels and oobchannels. - // Therefore, channel is terminated without relying on LoadBalancer to shutdown oobchannels. 
- verify(mockTransportFactory, never()) - .newClientTransport( - any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); - } - @Test public void subchannelChannel_normalUsage() { createChannel(); @@ -2188,6 +2077,7 @@ public void lbHelper_getNameResolverArgs() { assertThat(args.getSynchronizationContext()) .isSameInstanceAs(helper.getSynchronizationContext()); assertThat(args.getServiceConfigParser()).isNotNull(); + assertThat(args.getMetricRecorder()).isNotNull(); } @Test @@ -2198,67 +2088,6 @@ public void lbHelper_getNonDefaultNameResolverRegistry() { .isNotSameInstanceAs(NameResolverRegistry.getDefaultRegistry()); } - @Test - public void refreshNameResolution_whenOobChannelConnectionFailed_notIdle() { - subtestNameResolutionRefreshWhenConnectionFailed(false); - } - - @Test - public void notRefreshNameResolution_whenOobChannelConnectionFailed_idle() { - subtestNameResolutionRefreshWhenConnectionFailed(true); - } - - private void subtestNameResolutionRefreshWhenConnectionFailed(boolean isIdle) { - FakeNameResolverFactory nameResolverFactory = - new FakeNameResolverFactory.Builder(expectedUri) - .setServers(Collections.singletonList(new EquivalentAddressGroup(socketAddress))) - .build(); - channelBuilder.nameResolverFactory(nameResolverFactory); - createChannel(); - OobChannel oobChannel = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), "oobAuthority"); - oobChannel.getSubchannel().requestConnection(); - - MockClientTransportInfo transportInfo = transports.poll(); - assertNotNull(transportInfo); - - FakeNameResolverFactory.FakeNameResolver resolver = nameResolverFactory.resolvers.remove(0); - - if (isIdle) { - channel.enterIdle(); - // Entering idle mode will result in a new resolver - resolver = nameResolverFactory.resolvers.remove(0); - } - - assertEquals(0, nameResolverFactory.resolvers.size()); - - int expectedRefreshCount = 0; - - // Transport closed when connecting - 
assertEquals(expectedRefreshCount, resolver.refreshCalled); - transportInfo.listener.transportShutdown(Status.UNAVAILABLE); - // When channel enters idle, new resolver is created but not started. - if (!isIdle) { - expectedRefreshCount++; - } - assertEquals(expectedRefreshCount, resolver.refreshCalled); - - timer.forwardNanos(RECONNECT_BACKOFF_INTERVAL_NANOS); - transportInfo = transports.poll(); - assertNotNull(transportInfo); - - transportInfo.listener.transportReady(); - - // Transport closed when ready - assertEquals(expectedRefreshCount, resolver.refreshCalled); - transportInfo.listener.transportShutdown(Status.UNAVAILABLE); - // When channel enters idle, new resolver is created but not started. - if (!isIdle) { - expectedRefreshCount++; - } - assertEquals(expectedRefreshCount, resolver.refreshCalled); - } - /** * Test that information such as the Call's context, MethodDescriptor, authority, executor are * propagated to newStream() and applyRequestMetadata(). @@ -2508,7 +2337,6 @@ public void getState_withRequestConnect_IdleWithLbRunning() { assertEquals(IDLE, channel.getState(true)); verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); - verify(mockPicker).requestConnection(); verify(mockLoadBalancer).requestConnection(); } @@ -3235,11 +3063,19 @@ public void channelTracing_nameResolvedEvent_zeorAndNonzeroBackends_usesListener assertThat(getStats(channel).channelTrace.events).hasSize(prevSize); prevSize = getStats(channel).channelTrace.events.size(); - nameResolverFactory.resolvers.get(0).listener.onError(Status.INTERNAL); + channel.syncContext.execute(() -> + nameResolverFactory.resolvers.get(0).listener.onResult2( + ResolutionResult.newBuilder() + .setAddressesOrError( + StatusOr.fromStatus(Status.INTERNAL)).build())); assertThat(getStats(channel).channelTrace.events).hasSize(prevSize + 1); prevSize = getStats(channel).channelTrace.events.size(); - nameResolverFactory.resolvers.get(0).listener.onError(Status.INTERNAL); + 
channel.syncContext.execute(() -> + nameResolverFactory.resolvers.get(0).listener.onResult2( + ResolutionResult.newBuilder() + .setAddressesOrError( + StatusOr.fromStatus(Status.INTERNAL)).build())); assertThat(getStats(channel).channelTrace.events).hasSize(prevSize); prevSize = getStats(channel).channelTrace.events.size(); @@ -3404,48 +3240,6 @@ public void channelTracing_subchannelStateChangeEvent() throws Exception { .build()); } - @Test - public void channelTracing_oobChannelStateChangeEvent() throws Exception { - channelBuilder.maxTraceEvents(10); - createChannel(); - OobChannel oobChannel = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), "authority"); - timer.forwardNanos(1234); - oobChannel.handleSubchannelStateChange( - ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); - assertThat(getStats(oobChannel).channelTrace.events).contains(new ChannelTrace.Event.Builder() - .setDescription("Entering CONNECTING state") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(timer.getTicker().read()) - .build()); - } - - @Test - public void channelTracing_oobChannelCreationEvents() throws Exception { - channelBuilder.maxTraceEvents(10); - createChannel(); - timer.forwardNanos(1234); - OobChannel oobChannel = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), "authority"); - assertThat(getStats(channel).channelTrace.events).contains(new ChannelTrace.Event.Builder() - .setDescription("Child OobChannel created") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(timer.getTicker().read()) - .setChannelRef(oobChannel) - .build()); - assertThat(getStats(oobChannel).channelTrace.events).contains(new ChannelTrace.Event.Builder() - .setDescription("OobChannel for [[[test-addr]/{}]] created") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(timer.getTicker().read()) - .build()); - 
assertThat(getStats(oobChannel.getInternalSubchannel()).channelTrace.events).contains( - new ChannelTrace.Event.Builder() - .setDescription("Subchannel for [[[test-addr]/{}]] created") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(timer.getTicker().read()) - .build()); - } - @Test public void channelsAndSubchannels_instrumented_state() throws Exception { createChannel(); @@ -3561,115 +3355,6 @@ private void channelsAndSubchannels_instrumented0(boolean success) throws Except } } - @Test - public void channelsAndSubchannels_oob_instrumented_success() throws Exception { - channelsAndSubchannels_oob_instrumented0(true); - } - - @Test - public void channelsAndSubchannels_oob_instrumented_fail() throws Exception { - channelsAndSubchannels_oob_instrumented0(false); - } - - private void channelsAndSubchannels_oob_instrumented0(boolean success) throws Exception { - // set up - ClientStream mockStream = mock(ClientStream.class); - createChannel(); - - OobChannel oobChannel = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), "oobauthority"); - AbstractSubchannel oobSubchannel = (AbstractSubchannel) oobChannel.getSubchannel(); - FakeClock callExecutor = new FakeClock(); - CallOptions options = - CallOptions.DEFAULT.withExecutor(callExecutor.getScheduledExecutorService()); - ClientCall call = oobChannel.newCall(method, options); - Metadata headers = new Metadata(); - - // Channel stat bumped when ClientCall.start() called - assertEquals(0, getStats(oobChannel).callsStarted); - call.start(mockCallListener, headers); - assertEquals(1, getStats(oobChannel).callsStarted); - - MockClientTransportInfo transportInfo = transports.poll(); - ConnectionClientTransport mockTransport = transportInfo.transport; - ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream( - same(method), same(headers), any(CallOptions.class), - ArgumentMatchers.any())) - .thenReturn(mockStream); - - // 
subchannel stat bumped when call gets assigned to it - assertEquals(0, getStats(oobSubchannel).callsStarted); - transportListener.transportReady(); - callExecutor.runDueTasks(); - verify(mockStream).start(streamListenerCaptor.capture()); - assertEquals(1, getStats(oobSubchannel).callsStarted); - - ClientStreamListener streamListener = streamListenerCaptor.getValue(); - call.halfClose(); - - // closing stream listener affects subchannel stats immediately - assertEquals(0, getStats(oobSubchannel).callsSucceeded); - assertEquals(0, getStats(oobSubchannel).callsFailed); - streamListener.closed(success ? Status.OK : Status.UNKNOWN, PROCESSED, new Metadata()); - if (success) { - assertEquals(1, getStats(oobSubchannel).callsSucceeded); - assertEquals(0, getStats(oobSubchannel).callsFailed); - } else { - assertEquals(0, getStats(oobSubchannel).callsSucceeded); - assertEquals(1, getStats(oobSubchannel).callsFailed); - } - - // channel stats bumped when the ClientCall.Listener is notified - assertEquals(0, getStats(oobChannel).callsSucceeded); - assertEquals(0, getStats(oobChannel).callsFailed); - callExecutor.runDueTasks(); - if (success) { - assertEquals(1, getStats(oobChannel).callsSucceeded); - assertEquals(0, getStats(oobChannel).callsFailed); - } else { - assertEquals(0, getStats(oobChannel).callsSucceeded); - assertEquals(1, getStats(oobChannel).callsFailed); - } - // oob channel is separate from the original channel - assertEquals(0, getStats(channel).callsSucceeded); - assertEquals(0, getStats(channel).callsFailed); - } - - @Test - public void channelsAndSubchannels_oob_instrumented_name() throws Exception { - createChannel(); - - String authority = "oobauthority"; - OobChannel oobChannel = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), authority); - assertEquals(authority, getStats(oobChannel).target); - } - - @Test - public void channelsAndSubchannels_oob_instrumented_state() throws Exception { - createChannel(); - - OobChannel 
oobChannel = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), "oobauthority"); - assertEquals(IDLE, getStats(oobChannel).state); - - oobChannel.getSubchannel().requestConnection(); - assertEquals(CONNECTING, getStats(oobChannel).state); - - MockClientTransportInfo transportInfo = transports.poll(); - ManagedClientTransport.Listener transportListener = transportInfo.listener; - - transportListener.transportReady(); - assertEquals(READY, getStats(oobChannel).state); - - // oobchannel state is separate from the ManagedChannel - assertEquals(CONNECTING, getStats(channel).state); - channel.shutdownNow(); - assertEquals(SHUTDOWN, getStats(channel).state); - assertEquals(SHUTDOWN, getStats(oobChannel).state); - } - @Test public void binaryLogInstalled() throws Exception { final SettableFuture intercepted = SettableFuture.create(); @@ -3755,8 +3440,6 @@ public double nextDouble() { verify(mockLoadBalancer).acceptResolvedAddresses(resolvedAddressCaptor.capture()); ResolvedAddresses resolvedAddresses = resolvedAddressCaptor.getValue(); assertThat(resolvedAddresses.getAddresses()).isEqualTo(nameResolverFactory.servers); - assertThat(resolvedAddresses.getAttributes() - .get(RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY)).isNotNull(); // simulating request connection and then transport ready after resolved address Subchannel subchannel = @@ -3785,7 +3468,7 @@ public double nextDouble() { Status.UNAVAILABLE, PROCESSED, new Metadata()); // in backoff - timer.forwardTime(5, TimeUnit.SECONDS); + timer.forwardTime(6, TimeUnit.SECONDS); assertThat(timer.getPendingTasks()).hasSize(1); verify(mockStream2, never()).start(any(ClientStreamListener.class)); @@ -3804,7 +3487,7 @@ public double nextDouble() { assertEquals("Channel shutdown invoked", statusCaptor.getValue().getDescription()); // backoff ends - timer.forwardTime(5, TimeUnit.SECONDS); + timer.forwardTime(6, TimeUnit.SECONDS); assertThat(timer.getPendingTasks()).isEmpty(); 
verify(mockStream2).start(streamListenerCaptor.capture()); verify(mockLoadBalancer, never()).shutdown(); @@ -3817,7 +3500,8 @@ public double nextDouble() { verify(mockLoadBalancer).shutdown(); // simulating the shutdown of load balancer triggers the shutdown of subchannel shutdownSafely(helper, subchannel); - transportInfo.listener.transportShutdown(Status.INTERNAL); + transportInfo.listener.transportShutdown(Status.INTERNAL, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo.listener.transportTerminated(); // simulating transport terminated assertTrue( "channel.isTerminated() is expected to be true but was false", @@ -3862,8 +3546,6 @@ public void hedgingScheduledThenChannelShutdown_hedgeShouldStillHappen_newCallSh verify(mockLoadBalancer).acceptResolvedAddresses(resolvedAddressCaptor.capture()); ResolvedAddresses resolvedAddresses = resolvedAddressCaptor.getValue(); assertThat(resolvedAddresses.getAddresses()).isEqualTo(nameResolverFactory.servers); - assertThat(resolvedAddresses.getAttributes() - .get(RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY)).isNotNull(); // simulating request connection and then transport ready after resolved address Subchannel subchannel = @@ -3924,7 +3606,8 @@ public void hedgingScheduledThenChannelShutdown_hedgeShouldStillHappen_newCallSh // simulating the shutdown of load balancer triggers the shutdown of subchannel shutdownSafely(helper, subchannel); // simulating transport shutdown & terminated - transportInfo.listener.transportShutdown(Status.INTERNAL); + transportInfo.listener.transportShutdown(Status.INTERNAL, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo.listener.transportTerminated(); assertTrue( "channel.isTerminated() is expected to be true but was false", @@ -4183,13 +3866,18 @@ public String getDefaultScheme() { return "fake"; } }; - channelBuilder.nameResolverFactory(factory).proxyDetector(neverProxy); + channelBuilder + .nameResolverFactory(factory) + .proxyDetector(neverProxy) + 
.setNameResolverArg(TEST_RESOLVER_CUSTOM_ARG_KEY, "test-value"); + createChannel(); NameResolver.Args args = capturedArgs.get(); assertThat(args).isNotNull(); assertThat(args.getDefaultPort()).isEqualTo(DEFAULT_PORT); assertThat(args.getProxyDetector()).isSameInstanceAs(neverProxy); + assertThat(args.getArg(TEST_RESOLVER_CUSTOM_ARG_KEY)).isEqualTo("test-value"); verify(offloadExecutor, never()).execute(any(Runnable.class)); args.getOffloadExecutor() @@ -4254,13 +3942,37 @@ public void nameResolverHelper_badConfigFails() { assertThat(coe.getError().getCause()).isInstanceOf(ClassCastException.class); } + @Test + public void nameResolverHelper_badParser_failsGracefully() { + boolean retryEnabled = false; + int maxRetryAttemptsLimit = 2; + int maxHedgedAttemptsLimit = 3; + + Throwable t = new Error("really poor config parser"); + when(mockLoadBalancerProvider.parseLoadBalancingPolicyConfig(any())).thenThrow(t); + ScParser parser = new ScParser( + retryEnabled, + maxRetryAttemptsLimit, + maxHedgedAttemptsLimit, + mockLoadBalancerProvider); + + ConfigOrError coe = parser.parseServiceConfig(ImmutableMap.of()); + + assertThat(coe.getError()).isNotNull(); + assertThat(coe.getError().getCode()).isEqualTo(Code.INTERNAL); + assertThat(coe.getError().getDescription()).contains("Unexpected error parsing service config"); + assertThat(coe.getError().getCause()).isSameInstanceAs(t); + } + @Test public void nameResolverHelper_noConfigChosen() { boolean retryEnabled = false; int maxRetryAttemptsLimit = 2; int maxHedgedAttemptsLimit = 3; + LoadBalancerRegistry registry = new LoadBalancerRegistry(); + registry.register(mockLoadBalancerProvider); AutoConfiguredLoadBalancerFactory autoConfiguredLoadBalancerFactory = - new AutoConfiguredLoadBalancerFactory("pick_first"); + new AutoConfiguredLoadBalancerFactory(registry, MOCK_POLICY_NAME); ScParser parser = new ScParser( retryEnabled, @@ -4595,7 +4307,7 @@ public void notUseDefaultImmediatelyIfEnableLookUp() throws Exception { int size = 
getStats(channel).channelTrace.events.size(); assertThat(getStats(channel).channelTrace.events.get(size - 1)) .isNotEqualTo(new ChannelTrace.Event.Builder() - .setDescription("Using default service config") + .setDescription("timer.forwardNanos(1234);") .setSeverity(ChannelTrace.Event.Severity.CT_INFO) .setTimestampNanos(timer.getTicker().read()) .build()); @@ -4689,7 +4401,7 @@ public void transportTerminated(Attributes transportAttrs) { assertEquals(1, readyCallbackCalled.get()); assertEquals(0, terminationCallbackCalled.get()); - transportListener.transportShutdown(Status.OK); + transportListener.transportShutdown(Status.OK, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportListener.transportTerminated(); assertEquals(1, terminationCallbackCalled.get()); @@ -4727,11 +4439,11 @@ public void validAuthorityTarget_overrideAuthority() throws Exception { URI targetUri = new URI("defaultscheme", "", "/foo.googleapis.com:8080", null); NameResolver nameResolver = ManagedChannelImpl.getNameResolver( - targetUri, null, nameResolverProvider, NAMERESOLVER_ARGS); + wrap(targetUri), null, nameResolverProvider, NAMERESOLVER_ARGS); assertThat(nameResolver.getServiceAuthority()).isEqualTo(serviceAuthority); nameResolver = ManagedChannelImpl.getNameResolver( - targetUri, overrideAuthority, nameResolverProvider, NAMERESOLVER_ARGS); + wrap(targetUri), overrideAuthority, nameResolverProvider, NAMERESOLVER_ARGS); assertThat(nameResolver.getServiceAuthority()).isEqualTo(overrideAuthority); } @@ -4760,7 +4472,7 @@ public String getDefaultScheme() { }; try { ManagedChannelImpl.getNameResolver( - URI.create("defaultscheme:///foo.gogoleapis.com:8080"), + wrap(URI.create("defaultscheme:///foo.gogoleapis.com:8080")), null, nameResolverProvider, NAMERESOLVER_ARGS); fail("Should fail"); } catch (IllegalArgumentException e) { @@ -4868,7 +4580,10 @@ final class FakeNameResolver extends NameResolver { void resolved() { if (error != null) { - listener.onError(error); + syncContext.execute(() -> 
+ listener.onResult2( + ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromStatus(error)).build())); return; } ResolutionResult.Builder builder = diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelOrphanWrapperTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelOrphanWrapperTest.java index 5ae97c69211..45fb3881722 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelOrphanWrapperTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelOrphanWrapperTest.java @@ -101,6 +101,45 @@ public boolean isDone() { } } + @Test + public void shutdown_withDelegateStillReferenced_doesNotLogWarning() { + ManagedChannel mc = new TestManagedChannel(); + final ReferenceQueue refqueue = new ReferenceQueue<>(); + ConcurrentMap refs = + new ConcurrentHashMap<>(); + + ManagedChannelOrphanWrapper wrapper = new ManagedChannelOrphanWrapper(mc, refqueue, refs); + WeakReference wrapperWeakRef = new WeakReference<>(wrapper); + + final List records = new ArrayList<>(); + Logger orphanLogger = Logger.getLogger(ManagedChannelOrphanWrapper.class.getName()); + Filter oldFilter = orphanLogger.getFilter(); + orphanLogger.setFilter(new Filter() { + @Override + public boolean isLoggable(LogRecord record) { + synchronized (records) { + records.add(record); + } + return false; + } + }); + + try { + wrapper.shutdown(); + wrapper = null; + + // Wait for the WRAPPER itself to be garbage collected + GcFinalization.awaitClear(wrapperWeakRef); + ManagedChannelReference.cleanQueue(refqueue); + + synchronized (records) { + assertEquals("Warning was logged even though shutdownNow() was called!", 0, records.size()); + } + } finally { + orphanLogger.setFilter(oldFilter); + } + } + @Test public void refCycleIsGCed() { ReferenceQueue refqueue = diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java index 493714dfd41..fefc37e4fdc 100644 --- 
a/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java @@ -20,6 +20,7 @@ import static io.grpc.MethodDescriptor.MethodType.UNARY; import static io.grpc.Status.Code.UNAVAILABLE; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; @@ -34,19 +35,13 @@ import io.grpc.testing.TestMethodDescriptors; import java.util.Collections; import java.util.Map; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class ManagedChannelServiceConfigTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void managedChannelServiceConfig_shouldParseHealthCheckingConfig() throws Exception { Map rawServiceConfig = @@ -79,10 +74,9 @@ public void createManagedChannelServiceConfig_failsOnDuplicateMethod() { Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name1, name2)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Duplicate method"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("Duplicate method name service/method"); } @Test @@ -92,10 +86,9 @@ public void createManagedChannelServiceConfig_failsOnDuplicateService() { Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name1, name2)); Map serviceConfig = 
ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Duplicate service"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("Duplicate service service"); } @Test @@ -107,10 +100,9 @@ public void createManagedChannelServiceConfig_failsOnDuplicateServiceMultipleCon Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig1, methodConfig2)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Duplicate service"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("Duplicate service service"); } @Test @@ -119,10 +111,9 @@ public void createManagedChannelServiceConfig_failsOnMethodNameWithEmptyServiceN Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("missing service name for method method1"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("missing service name for method method1"); } @Test @@ -131,10 +122,9 @@ public void createManagedChannelServiceConfig_failsOnMethodNameWithoutServiceNam Map methodConfig = ImmutableMap.of("name", 
ImmutableList.of(name)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("missing service name for method method1"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("missing service name for method method1"); } @Test @@ -143,10 +133,9 @@ public void createManagedChannelServiceConfig_failsOnMissingServiceName() { Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("missing service"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("missing service name for method method"); } @Test diff --git a/core/src/test/java/io/grpc/internal/ManagedClientTransportTest.java b/core/src/test/java/io/grpc/internal/ManagedClientTransportTest.java index 0af88a62728..5ddea08131b 100644 --- a/core/src/test/java/io/grpc/internal/ManagedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedClientTransportTest.java @@ -32,7 +32,7 @@ public class ManagedClientTransportTest { public void testListener() { ManagedClientTransport.Listener listener = new ManagedClientTransport.Listener() { @Override - public void transportShutdown(Status s) {} + public void transportShutdown(Status s, DisconnectError e) {} @Override public void transportTerminated() {} @@ -45,7 +45,7 @@ public void transportInUse(boolean inUse) {} }; // 
Test that the listener methods do not throw. - listener.transportShutdown(Status.OK); + listener.transportShutdown(Status.OK, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); listener.transportTerminated(); listener.transportReady(); listener.transportInUse(true); diff --git a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java index 1ec1ccb2082..54758bc096f 100644 --- a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java @@ -20,6 +20,7 @@ import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeTrue; import static org.mockito.ArgumentMatchers.anyInt; @@ -53,10 +54,8 @@ import java.util.concurrent.TimeUnit; import java.util.zip.GZIPOutputStream; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; import org.junit.experimental.runners.Enclosed; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.junit.runners.Parameterized; @@ -133,7 +132,7 @@ public void simplePayload() { assertEquals(Bytes.asList(new byte[]{3, 14}), bytes(producer.getValue().next())); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock, 2, 2); + checkStats(tracer, transportTracer.getStats(), fakeClock, useGzipInflatingBuffer, 2, 2); } @Test @@ -148,7 +147,7 @@ public void smallCombinedPayloads() { verify(listener, atLeastOnce()).bytesRead(anyInt()); assertEquals(Bytes.asList(new byte[]{14, 15}), bytes(streams.get(1).next())); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock, 1, 1, 2, 2); + checkStats(tracer, 
transportTracer.getStats(), fakeClock, useGzipInflatingBuffer, 1, 1, 2, 2); } @Test @@ -162,7 +161,7 @@ public void endOfStreamWithPayloadShouldNotifyEndOfStream() { verify(listener).deframerClosed(false); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock, 1, 1); + checkStats(tracer, transportTracer.getStats(), fakeClock, useGzipInflatingBuffer, 1, 1); } @Test @@ -177,7 +176,7 @@ public void endOfStreamShouldNotifyEndOfStream() { } verify(listener).deframerClosed(false); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock); + checkStats(tracer, transportTracer.getStats(), fakeClock, false); } @Test @@ -189,7 +188,7 @@ public void endOfStreamWithPartialMessageShouldNotifyDeframerClosedWithPartialMe verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener).deframerClosed(true); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock); + checkStats(tracer, transportTracer.getStats(), fakeClock, false); } @Test @@ -206,7 +205,7 @@ public void endOfStreamWithInvalidGzipBlockShouldNotifyDeframerClosedWithPartial deframer.closeWhenComplete(); verify(listener).deframerClosed(true); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock); + checkStats(tracer, transportTracer.getStats(), fakeClock, false); } @Test @@ -228,10 +227,11 @@ public void payloadSplitBetweenBuffers() { tracer, transportTracer.getStats(), fakeClock, + true, 7 /* msg size */ + 2 /* second buffer adds two bytes of overhead in deflate block */, 7); } else { - checkStats(tracer, transportTracer.getStats(), fakeClock, 7, 7); + checkStats(tracer, transportTracer.getStats(), fakeClock, false, 7, 7); } } @@ -248,7 +248,7 @@ public void frameHeaderSplitBetweenBuffers() { assertEquals(Bytes.asList(new byte[]{3}), bytes(producer.getValue().next())); verify(listener, 
atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock, 1, 1); + checkStats(tracer, transportTracer.getStats(), fakeClock, useGzipInflatingBuffer, 1, 1); } @Test @@ -259,7 +259,7 @@ public void emptyPayload() { assertEquals(Bytes.asList(), bytes(producer.getValue().next())); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock, 0, 0); + checkStats(tracer, transportTracer.getStats(), fakeClock, useGzipInflatingBuffer, 0, 0); } @Test @@ -273,9 +273,10 @@ public void largerFrameSize() { verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); if (useGzipInflatingBuffer) { - checkStats(tracer, transportTracer.getStats(), fakeClock, 8 /* compressed size */, 1000); + checkStats(tracer, transportTracer.getStats(), fakeClock,true, + 8 /* compressed size */, 1000); } else { - checkStats(tracer, transportTracer.getStats(), fakeClock, 1000, 1000); + checkStats(tracer, transportTracer.getStats(), fakeClock, false, 1000, 1000); } } @@ -292,7 +293,7 @@ public void endOfStreamCallbackShouldWaitForMessageDelivery() { verify(listener).deframerClosed(false); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock, 1, 1); + checkStats(tracer, transportTracer.getStats(), fakeClock, useGzipInflatingBuffer, 1, 1); } @Test @@ -308,6 +309,7 @@ public void compressed() { verify(listener).messagesAvailable(producer.capture()); assertEquals(Bytes.asList(new byte[1000]), bytes(producer.getValue().next())); verify(listener, atLeastOnce()).bytesRead(anyInt()); + checkStats(tracer, transportTracer.getStats(), fakeClock, true, 29, 1000); verifyNoMoreInteractions(listener); } @@ -338,9 +340,6 @@ public Void answer(InvocationOnMock invocation) throws Throwable { @RunWith(JUnit4.class) public static class 
SizeEnforcingInputStreamTests { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); private TestBaseStreamTracer tracer = new TestBaseStreamTracer(); private StatsTraceContext statsTraceCtx = new StatsTraceContext(new StreamTracer[]{tracer}); @@ -378,11 +377,12 @@ public void sizeEnforcingInputStream_readByteAboveLimit() throws IOException { new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx); try { - thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); - - while (stream.read() != -1) { - } + StatusRuntimeException e = assertThrows(StatusRuntimeException.class, () -> { + while (stream.read() != -1) { + } + }); + assertThat(e).hasMessageThat() + .isEqualTo("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds maximum size 2"); } finally { stream.close(); } @@ -424,10 +424,10 @@ public void sizeEnforcingInputStream_readAboveLimit() throws IOException { byte[] buf = new byte[10]; try { - thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); - - stream.read(buf, 0, buf.length); + StatusRuntimeException e = assertThrows(StatusRuntimeException.class, + () -> stream.read(buf, 0, buf.length)); + assertThat(e).hasMessageThat() + .isEqualTo("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds maximum size 2"); } finally { stream.close(); } @@ -467,10 +467,9 @@ public void sizeEnforcingInputStream_skipAboveLimit() throws IOException { new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx); try { - thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); - - stream.skip(4); + StatusRuntimeException e = assertThrows(StatusRuntimeException.class, () -> stream.skip(4)); + assertThat(e).hasMessageThat() + 
.isEqualTo("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds maximum size 2"); } finally { stream.close(); } @@ -502,7 +501,8 @@ public void sizeEnforcingInputStream_markReset() throws IOException { * @param sizes in the format {wire0, uncompressed0, wire1, uncompressed1, ...} */ private static void checkStats( - TestBaseStreamTracer tracer, TransportStats transportStats, FakeClock clock, long... sizes) { + TestBaseStreamTracer tracer, TransportStats transportStats, FakeClock clock, + boolean compressed, long... sizes) { assertEquals(0, sizes.length % 2); int count = sizes.length / 2; long expectedWireSize = 0; @@ -510,7 +510,8 @@ private static void checkStats( for (int i = 0; i < count; i++) { assertEquals("inboundMessage(" + i + ")", tracer.nextInboundEvent()); assertEquals( - String.format(Locale.US, "inboundMessageRead(%d, %d, -1)", i, sizes[i * 2]), + String.format(Locale.US, "inboundMessageRead(%d, %d, %d)", i, sizes[i * 2], + compressed ? -1 : sizes[i * 2 + 1]), tracer.nextInboundEvent()); expectedWireSize += sizes[i * 2]; expectedUncompressedSize += sizes[i * 2 + 1]; diff --git a/core/src/test/java/io/grpc/internal/MetricRecorderImplTest.java b/core/src/test/java/io/grpc/internal/MetricRecorderImplTest.java index 08f34a267f9..33bf9bb41e2 100644 --- a/core/src/test/java/io/grpc/internal/MetricRecorderImplTest.java +++ b/core/src/test/java/io/grpc/internal/MetricRecorderImplTest.java @@ -32,6 +32,7 @@ import io.grpc.LongCounterMetricInstrument; import io.grpc.LongGaugeMetricInstrument; import io.grpc.LongHistogramMetricInstrument; +import io.grpc.LongUpDownCounterMetricInstrument; import io.grpc.MetricInstrumentRegistry; import io.grpc.MetricInstrumentRegistryAccessor; import io.grpc.MetricRecorder; @@ -79,6 +80,9 @@ public class MetricRecorderImplTest { private final LongGaugeMetricInstrument longGaugeInstrument = registry.registerLongGauge("gauge0", DESCRIPTION, UNIT, REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, ENABLED); + private final 
LongUpDownCounterMetricInstrument longUpDownCounterInstrument = + registry.registerLongUpDownCounter("upDownCounter0", DESCRIPTION, UNIT, + REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, ENABLED); private MetricRecorder recorder; @Before @@ -88,7 +92,7 @@ public void setUp() { @Test public void addCounter() { - when(mockSink.getMeasuresSize()).thenReturn(4); + when(mockSink.getMeasuresSize()).thenReturn(6); recorder.addDoubleCounter(doubleCounterInstrument, 1.0, REQUIRED_LABEL_VALUES, OPTIONAL_LABEL_VALUES); @@ -100,6 +104,12 @@ public void addCounter() { verify(mockSink, times(2)).addLongCounter(eq(longCounterInstrument), eq(1L), eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); + recorder.addLongUpDownCounter(longUpDownCounterInstrument, -10, REQUIRED_LABEL_VALUES, + OPTIONAL_LABEL_VALUES); + verify(mockSink, times(2)) + .addLongUpDownCounter(eq(longUpDownCounterInstrument), eq(-10L), + eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); + verify(mockSink, never()).updateMeasures(registry.getMetricInstruments()); } @@ -190,6 +200,13 @@ public void newRegisteredMetricUpdateMeasures() { verify(mockSink, times(2)) .registerBatchCallback(any(Runnable.class), eq(longGaugeInstrument)); registration.close(); + + // Long UpDown Counter + recorder.addLongUpDownCounter(longUpDownCounterInstrument, -10, REQUIRED_LABEL_VALUES, + OPTIONAL_LABEL_VALUES); + verify(mockSink, times(12)).updateMeasures(anyList()); + verify(mockSink, times(2)).addLongUpDownCounter(eq(longUpDownCounterInstrument), eq(-10L), + eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); } @Test(expected = IllegalArgumentException.class) @@ -208,6 +225,13 @@ public void addLongCounterMismatchedRequiredLabelValues() { OPTIONAL_LABEL_VALUES); } + @Test(expected = IllegalArgumentException.class) + public void addLongUpDownCounterMismatchedRequiredLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(6); + recorder.addLongUpDownCounter(longUpDownCounterInstrument, 1, ImmutableList.of(), + 
OPTIONAL_LABEL_VALUES); + } + @Test(expected = IllegalArgumentException.class) public void recordDoubleHistogramMismatchedRequiredLabelValues() { when(mockSink.getMeasuresSize()).thenReturn(4); @@ -260,6 +284,13 @@ public void addLongCounterMismatchedOptionalLabelValues() { ImmutableList.of()); } + @Test(expected = IllegalArgumentException.class) + public void addLongUpDownCounterMismatchedOptionalLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(6); + recorder.addLongUpDownCounter(longUpDownCounterInstrument, 1, REQUIRED_LABEL_VALUES, + ImmutableList.of()); + } + @Test(expected = IllegalArgumentException.class) public void recordDoubleHistogramMismatchedOptionalLabelValues() { when(mockSink.getMeasuresSize()).thenReturn(4); diff --git a/core/src/test/java/io/grpc/internal/NoopClientStreamTest.java b/core/src/test/java/io/grpc/internal/NoopClientStreamTest.java new file mode 100644 index 00000000000..d68642dad85 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/NoopClientStreamTest.java @@ -0,0 +1,44 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.io.InputStream; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link NoopClientStream}. 
+ */ +@RunWith(JUnit4.class) +public class NoopClientStreamTest { + + @Test + public void writeMessageShouldCloseInputStream() throws Exception { + // NoopClientStream.writeMessage() is called when a stream is cancelled or failed + // before the real transport stream is established (e.g. via DelayedStream draining + // buffered messages to NoopClientStream on cancellation, or FailingClientStream + // which extends NoopClientStream). The InputStream must be closed to avoid leaking + // resources such as ref-counted ByteBufs. + InputStream message = mock(InputStream.class); + NoopClientStream.INSTANCE.writeMessage(message); + verify(message).close(); + } +} diff --git a/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java b/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java index 63915bddc99..eb7b40257c0 100644 --- a/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java @@ -23,14 +23,18 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.InternalEquivalentAddressGroup.ATTR_WEIGHT; import static io.grpc.LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY; import static io.grpc.LoadBalancer.HEALTH_CONSUMER_LISTENER_ARG_KEY; import static io.grpc.LoadBalancer.IS_PETIOLE_POLICY; import static io.grpc.internal.PickFirstLeafLoadBalancer.CONNECTION_DELAY_INTERVAL_MS; +import static io.grpc.internal.PickFirstLeafLoadBalancer.isSerializingRetries; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assume.assumeTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import 
static org.mockito.ArgumentMatchers.eq; @@ -65,15 +69,18 @@ import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.internal.PickFirstLeafLoadBalancer.PickFirstLeafLoadBalancerConfig; +import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Queue; +import java.util.Random; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import org.junit.After; -import org.junit.Assume; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -92,14 +99,22 @@ public class PickFirstLeafLoadBalancerTest { public static final Status CONNECTION_ERROR = Status.UNAVAILABLE.withDescription("Simulated connection error"); - - @Parameterized.Parameters(name = "{0}") - public static List enableHappyEyeballs() { - return Arrays.asList(true, false); + public static final String GRPC_SERIALIZE_RETRIES = "GRPC_SERIALIZE_RETRIES"; + + @Parameterized.Parameters(name = "{0}-{1}") + public static List data() { + return Arrays.asList(new Object[][] { + {false, false}, + {false, true}, + {true, false}}); } - @Parameterized.Parameter + @Parameterized.Parameter(value = 0) + public boolean serializeRetries; + + @Parameterized.Parameter(value = 1) public boolean enableHappyEyeballs; + private PickFirstLeafLoadBalancer loadBalancer; private final List servers = Lists.newArrayList(); private static final Attributes.Key FOO = Attributes.Key.create("foo"); @@ -137,13 +152,25 @@ public void uncaughtException(Thread t, Throwable e) { private PickSubchannelArgs mockArgs; private String originalHappyEyeballsEnabledValue; + private String originalSerializeRetriesValue; + private boolean originalWeightedShuffling; + + private long backoffMillis; @Before public void setUp() { + assumeTrue(!serializeRetries || !enableHappyEyeballs); // they are not compatible + + 
backoffMillis = TimeUnit.SECONDS.toMillis(1); + originalSerializeRetriesValue = System.getProperty(GRPC_SERIALIZE_RETRIES); + System.setProperty(GRPC_SERIALIZE_RETRIES, Boolean.toString(serializeRetries)); + originalHappyEyeballsEnabledValue = System.getProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS); System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS, - enableHappyEyeballs ? "true" : "false"); + Boolean.toString(enableHappyEyeballs)); + + originalWeightedShuffling = PickFirstLeafLoadBalancer.weightedShuffling; for (int i = 1; i <= 5; i++) { SocketAddress addr = new FakeSocketAddress("server" + i); @@ -176,12 +203,18 @@ public void setUp() { @After public void tearDown() { + if (originalSerializeRetriesValue == null) { + System.clearProperty(GRPC_SERIALIZE_RETRIES); + } else { + System.setProperty(GRPC_SERIALIZE_RETRIES, originalSerializeRetriesValue); + } if (originalHappyEyeballsEnabledValue == null) { System.clearProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS); } else { System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS, originalHappyEyeballsEnabledValue); } + PickFirstLeafLoadBalancer.weightedShuffling = originalWeightedShuffling; loadBalancer.shutdown(); verifyNoMoreInteractions(mockArgs); @@ -217,6 +250,12 @@ public void pickAfterResolved() { verifyNoMoreInteractions(mockHelper); } + @Test + public void pickAfterResolved_shuffle_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffle(); + } + @Test public void pickAfterResolved_shuffle() { servers.remove(4); @@ -280,6 +319,103 @@ public void pickAfterResolved_noShuffle() { assertNotNull(pickerCaptor.getValue().pickSubchannel(mockArgs)); } + @Test + public void pickAfterResolved_shuffleImplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + 
pickAfterResolved_shuffleImplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleImplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup(new FakeSocketAddress("server1")); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup(new FakeSocketAddress("server2")); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup(new FakeSocketAddress("server3")); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleExplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_noWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = false; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), 
Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_weightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = true; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(75); // 100*12/16 + assertThat(counts[1]).isWithin(7).of(19); // 100*3/16 + assertThat(counts[2]).isWithin(7).of(6); // 100*1/16 + } + + /** Returns int[index_of_eag] array with number of times each eag was selected. 
*/ + private int[] countAddressSelections(int trials, List eags) { + int[] counts = new int[eags.size()]; + Random random = new Random(1); + for (int i = 0; i < trials; i++) { + RecordingHelper helper = new RecordingHelper(); + LoadBalancer lb = new PickFirstLeafLoadBalancer(helper); + assertThat(lb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(eags) + .setAttributes(affinity) + .setLoadBalancingPolicyConfig( + new PickFirstLeafLoadBalancerConfig(true, random.nextLong())) + .build())) + .isSameInstanceAs(Status.OK); + helper.subchannels.remove().listener.onSubchannelState( + ConnectivityStateInfo.forNonError(READY)); + + assertThat(helper.state).isEqualTo(READY); + Subchannel subchannel = helper.picker.pickSubchannel(mockArgs).getSubchannel(); + counts[eags.indexOf(subchannel.getAddresses())]++; + + lb.shutdown(); + } + return counts; + } + @Test public void requestConnectionPicker() { // Set up @@ -498,6 +634,9 @@ public void healthCheckFlow() { inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs) .getSubchannel()).isSameInstanceAs(mockSubchannel1); + verify(mockHelper, atLeast(0)).getSynchronizationContext(); + verify(mockHelper, atLeast(0)).getScheduledExecutorService(); + verifyNoMoreInteractions(mockHelper); healthListener2.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); verifyNoMoreInteractions(mockHelper); @@ -520,20 +659,7 @@ public void pickAfterStateChangeAfterResolution() { inOrder.verify(mockSubchannel1).start(stateListenerCaptor.capture()); stateListeners[0] = stateListenerCaptor.getValue(); - if (enableHappyEyeballs) { - forwardTimeByConnectionDelay(); - inOrder.verify(mockSubchannel2).start(stateListenerCaptor.capture()); - stateListeners[1] = stateListenerCaptor.getValue(); - forwardTimeByConnectionDelay(); - inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture()); - stateListeners[2] = 
stateListenerCaptor.getValue(); - forwardTimeByConnectionDelay(); - inOrder.verify(mockSubchannel4).start(stateListenerCaptor.capture()); - stateListeners[3] = stateListenerCaptor.getValue(); - } - - reset(mockHelper); - + stateListeners[0].onSubchannelState(ConnectivityStateInfo.forNonError(READY)); stateListeners[0].onSubchannelState(ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).refreshNameResolution(); inOrder.verify(mockHelper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); @@ -543,11 +669,23 @@ public void pickAfterStateChangeAfterResolution() { stateListeners[0].onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); Status error = Status.UNAVAILABLE.withDescription("boom!"); + reset(mockHelper); if (enableHappyEyeballs) { - for (SubchannelStateListener listener : stateListeners) { - listener.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); - } + stateListeners[0].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + forwardTimeByConnectionDelay(); + inOrder.verify(mockSubchannel2).start(stateListenerCaptor.capture()); + stateListeners[1] = stateListenerCaptor.getValue(); + stateListeners[1].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + forwardTimeByConnectionDelay(); + inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture()); + stateListeners[2] = stateListenerCaptor.getValue(); + stateListeners[2].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + forwardTimeByConnectionDelay(); + inOrder.verify(mockSubchannel4).start(stateListenerCaptor.capture()); + stateListeners[3] = stateListenerCaptor.getValue(); + stateListeners[3].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + forwardTimeByConnectionDelay(); } else { stateListeners[0].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); for (int i = 1; i < stateListeners.length; i++) { @@ -589,6 +727,8 @@ public void 
pickAfterResolutionAfterTransientValue() { // Transition from TRANSIENT_ERROR to CONNECTING should also be ignored. stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + verify(mockHelper, atLeast(0)).getSynchronizationContext(); + verify(mockHelper, atLeast(0)).getScheduledExecutorService(); verifyNoMoreInteractions(mockHelper); assertEquals(error, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); } @@ -619,6 +759,8 @@ public void pickWithDupAddressesUpDownUp() { // Transition from TRANSIENT_ERROR to CONNECTING should also be ignored. stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + verify(mockHelper, atLeast(0)).getSynchronizationContext(); + verify(mockHelper, atLeast(0)).getScheduledExecutorService(); verifyNoMoreInteractions(mockHelper); assertEquals(error, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); @@ -651,6 +793,8 @@ public void pickWithDupEagsUpDownUp() { // Transition from TRANSIENT_ERROR to CONNECTING should also be ignored. 
stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + verify(mockHelper, atLeast(0)).getSynchronizationContext(); + verify(mockHelper, atLeast(0)).getScheduledExecutorService(); verifyNoMoreInteractions(mockHelper); assertEquals(error, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); @@ -1302,6 +1446,11 @@ public void updateAddresses_disjoint_transient_failure() { loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); + if (serializeRetries) { + inOrder.verify(mockSubchannel3, never()).start(stateListenerCaptor.capture()); + forwardTimeByBackoffDelay(); + } + // subchannel 3 still attempts a connection even though we stay in transient failure assertEquals(TRANSIENT_FAILURE, loadBalancer.getConcludedConnectivityState()); inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture()); @@ -1518,6 +1667,8 @@ public void updateAddresses_intersecting_ready() { @Test public void updateAddresses_intersecting_transient_failure() { + assumeTrue(!isSerializingRetries()); + // Starting first connection attempt InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3, mockSubchannel4); // captor: captures @@ -1782,6 +1933,8 @@ public void updateAddresses_identical_ready() { @Test public void updateAddresses_identical_transient_failure() { + assumeTrue(!isSerializingRetries()); + InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3, mockSubchannel4); // Creating first set of endpoints/addresses @@ -1835,6 +1988,45 @@ public void updateAddresses_identical_transient_failure() { assertEquals(PickResult.withSubchannel(mockSubchannel1), picker.pickSubchannel(mockArgs)); } + @Test + public void updateAddresses_identicalSingleAddress_connecting() { + // Creating first set of endpoints/addresses + List oldServers = Lists.newArrayList(servers.get(0)); + + // Accept Addresses and verify proper connection 
flow + assertEquals(IDLE, loadBalancer.getConcludedConnectivityState()); + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(oldServers).setAttributes(affinity).build()); + verify(mockSubchannel1).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener = stateListenerCaptor.getValue(); + assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); + + // First connection attempt is successful + stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); + fakeClock.forwardTime(CONNECTION_DELAY_INTERVAL_MS, TimeUnit.MILLISECONDS); + + // verify that picker returns no subchannel + verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + SubchannelPicker picker = pickerCaptor.getValue(); + assertEquals(PickResult.withNoResult(), picker.pickSubchannel(mockArgs)); + + // Accept same resolved addresses to update + reset(mockHelper); + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(oldServers).setAttributes(affinity).build()); + fakeClock.forwardTime(CONNECTION_DELAY_INTERVAL_MS, TimeUnit.MILLISECONDS); + + // Verify that no new subchannels were created or started + verify(mockSubchannel2, never()).start(any()); + assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); + + // verify that picker hasn't changed via checking mock helper's interactions + verify(mockHelper, atLeast(0)).getSynchronizationContext(); // Don't care + verify(mockHelper, atLeast(0)).getScheduledExecutorService(); + verifyNoMoreInteractions(mockHelper); + } + @Test public void twoAddressesSeriallyConnect() { // Starting first connection attempt @@ -2096,18 +2288,20 @@ public void lastAddressFailingNotTransientFailure() { loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); - // Verify that 
no new subchannels were created or started + // Subchannel 2 should be reused since it was trying to connect and is present. inOrder.verify(mockSubchannel1).shutdown(); - inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture()); - SubchannelStateListener stateListener3 = stateListenerCaptor.getValue(); - inOrder.verify(mockSubchannel3).requestConnection(); + inOrder.verify(mockSubchannel3, never()).start(stateListenerCaptor.capture()); assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); - // Second address connection attempt is unsuccessful, but should not go into transient failure + // Second address connection attempt is unsuccessful, so since at end, but don't have all + // subchannels, schedule a backoff for the first address stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(CONNECTION_ERROR)); + fakeClock.forwardTime(1, TimeUnit.SECONDS); + inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener3 = stateListenerCaptor.getValue(); assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); - // Third address connection attempt is unsuccessful, now we enter transient failure + // Third address connection attempt is unsuccessful, now we enter TF, do name resolution stateListener3.onSubchannelState(ConnectivityStateInfo.forTransientFailure(CONNECTION_ERROR)); assertEquals(TRANSIENT_FAILURE, loadBalancer.getConcludedConnectivityState()); @@ -2295,7 +2489,7 @@ public void ready_then_transient_failure_again() { @Test public void happy_eyeballs_trigger_connection_delay() { - Assume.assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs + assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs // Starting first connection attempt InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3, mockSubchannel4); @@ -2340,7 +2534,7 @@ public void happy_eyeballs_trigger_connection_delay() { 
@Test public void happy_eyeballs_connection_results_happen_after_get_to_end() { - Assume.assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs + assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3); Status error = Status.UNAUTHENTICATED.withDescription("simulated failure"); @@ -2393,7 +2587,7 @@ public void happy_eyeballs_connection_results_happen_after_get_to_end() { @Test public void happy_eyeballs_pick_pushes_index_over_end() { - Assume.assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs + assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3, mockSubchannel2n2, mockSubchannel3n2); @@ -2471,7 +2665,7 @@ public void happy_eyeballs_pick_pushes_index_over_end() { @Test public void happy_eyeballs_fail_then_trigger_connection_delay() { - Assume.assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs + assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs // Starting first connection attempt InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3); assertEquals(IDLE, loadBalancer.getConcludedConnectivityState()); @@ -2550,6 +2744,44 @@ public void advance_index_then_request_connection() { loadBalancer.requestConnection(); // should be handled without throwing exception } + @Test + public void serialized_retries_two_passes() { + assumeTrue(serializeRetries); // This test is only for serialized retries + + InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3); + Status error = Status.UNAUTHENTICATED.withDescription("simulated failure"); + + List addrs = + Lists.newArrayList(servers.get(0), servers.get(1), servers.get(2)); + Subchannel[] subchannels = new Subchannel[]{mockSubchannel1, mockSubchannel2, mockSubchannel3}; 
+ SubchannelStateListener[] listeners = new SubchannelStateListener[subchannels.length]; + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(addrs).build()); + forwardTimeByConnectionDelay(2); + for (int i = 0; i < subchannels.length; i++) { + inOrder.verify(subchannels[i]).start(stateListenerCaptor.capture()); + inOrder.verify(subchannels[i]).requestConnection(); + listeners[i] = stateListenerCaptor.getValue(); + listeners[i].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + } + assertEquals(TRANSIENT_FAILURE, loadBalancer.getConcludedConnectivityState()); + assertFalse("Index should be at end", loadBalancer.isIndexValid()); + + forwardTimeByBackoffDelay(); // should trigger retry + for (int i = 0; i < subchannels.length; i++) { + inOrder.verify(subchannels[i]).requestConnection(); + listeners[i].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); // cascade + } + inOrder.verify(subchannels[0], never()).requestConnection(); // should wait for backoff delay + + forwardTimeByBackoffDelay(); // should trigger retry again + for (int i = 0; i < subchannels.length; i++) { + inOrder.verify(subchannels[i]).requestConnection(); + assertEquals(i, loadBalancer.getIndexLocation()); + listeners[i].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); // cascade + } + } + @Test public void index_looping() { Attributes.Key key = Attributes.Key.create("some-key"); @@ -2564,7 +2796,7 @@ public void index_looping() { PickFirstLeafLoadBalancer.Index index = new PickFirstLeafLoadBalancer.Index(Arrays.asList( new EquivalentAddressGroup(Arrays.asList(addr1, addr2), attr1), new EquivalentAddressGroup(Arrays.asList(addr3), attr2), - new EquivalentAddressGroup(Arrays.asList(addr4, addr5), attr3))); + new EquivalentAddressGroup(Arrays.asList(addr4, addr5), attr3)), enableHappyEyeballs); assertThat(index.getCurrentAddress()).isSameInstanceAs(addr1); 
assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attr1); assertThat(index.isAtBeginning()).isTrue(); @@ -2623,7 +2855,7 @@ public void index_updateGroups_resets() { SocketAddress addr3 = new FakeSocketAddress("addr3"); PickFirstLeafLoadBalancer.Index index = new PickFirstLeafLoadBalancer.Index(Arrays.asList( new EquivalentAddressGroup(Arrays.asList(addr1)), - new EquivalentAddressGroup(Arrays.asList(addr2, addr3)))); + new EquivalentAddressGroup(Arrays.asList(addr2, addr3))), enableHappyEyeballs); index.increment(); index.increment(); // We want to make sure both groupIndex and addressIndex are reset @@ -2640,7 +2872,7 @@ public void index_seekTo() { SocketAddress addr3 = new FakeSocketAddress("addr3"); PickFirstLeafLoadBalancer.Index index = new PickFirstLeafLoadBalancer.Index(Arrays.asList( new EquivalentAddressGroup(Arrays.asList(addr1, addr2)), - new EquivalentAddressGroup(Arrays.asList(addr3)))); + new EquivalentAddressGroup(Arrays.asList(addr3))), enableHappyEyeballs); assertThat(index.seekTo(addr3)).isTrue(); assertThat(index.getCurrentAddress()).isSameInstanceAs(addr3); assertThat(index.seekTo(addr1)).isTrue(); @@ -2652,6 +2884,83 @@ public void index_seekTo() { assertThat(index.getCurrentAddress()).isSameInstanceAs(addr2); } + @Test + public void index_interleaving() { + InetSocketAddress addr1_6 = new InetSocketAddress("f38:1:1", 1234); + InetSocketAddress addr1_4 = new InetSocketAddress("10.1.1.1", 1234); + InetSocketAddress addr2_4 = new InetSocketAddress("10.1.1.2", 1234); + InetSocketAddress addr3_4 = new InetSocketAddress("10.1.1.3", 1234); + InetSocketAddress addr4_4 = new InetSocketAddress("10.1.1.4", 1234); + InetSocketAddress addr4_6 = new InetSocketAddress("f38:1:4", 1234); + + Attributes attrs1 = Attributes.newBuilder().build(); + Attributes attrs2 = Attributes.newBuilder().build(); + Attributes attrs3 = Attributes.newBuilder().build(); + Attributes attrs4 = Attributes.newBuilder().build(); + + PickFirstLeafLoadBalancer.Index index 
= new PickFirstLeafLoadBalancer.Index(Arrays.asList( + new EquivalentAddressGroup(Arrays.asList(addr1_4, addr1_6), attrs1), + new EquivalentAddressGroup(Arrays.asList(addr2_4), attrs2), + new EquivalentAddressGroup(Arrays.asList(addr3_4), attrs3), + new EquivalentAddressGroup(Arrays.asList(addr4_4, addr4_6), attrs4)), enableHappyEyeballs); + + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr1_4); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs1); + assertThat(index.isAtBeginning()).isTrue(); + + index.increment(); + assertThat(index.isValid()).isTrue(); + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr1_6); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs1); + assertThat(index.isAtBeginning()).isFalse(); + + index.increment(); + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr2_4); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs2); + + index.increment(); + if (enableHappyEyeballs) { + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr4_6); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs4); + } else { + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr3_4); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs3); + } + + index.increment(); + if (enableHappyEyeballs) { + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr3_4); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs3); + } else { + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr4_4); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs4); + } + + // Move to last entry + assertThat(index.increment()).isTrue(); + assertThat(index.isValid()).isTrue(); + if (enableHappyEyeballs) { + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr4_4); + } else { + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr4_6); + } + + // Move off of the end + assertThat(index.increment()).isFalse(); + 
assertThat(index.isValid()).isFalse(); + assertThrows(IllegalStateException.class, index::getCurrentAddress); + + // Reset + index.reset(); + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr1_4); + assertThat(index.isAtBeginning()).isTrue(); + assertThat(index.isValid()).isTrue(); + + // Seek to an address + assertThat(index.seekTo(addr4_4)).isTrue(); + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr4_4); + } + private static class FakeSocketAddress extends SocketAddress { final String name; @@ -2689,6 +2998,11 @@ private void forwardTimeByConnectionDelay(int times) { } } + private void forwardTimeByBackoffDelay() { + backoffMillis = (long) (backoffMillis * 1.8); // backoff factor default is 1.6 with Jitter .2 + fakeClock.forwardTime(backoffMillis, TimeUnit.MILLISECONDS); + } + private void acceptXSubchannels(int num) { List newServers = new ArrayList<>(); for (int i = 0; i < num; i++) { @@ -2747,13 +3061,7 @@ public String toString() { } } - private class MockHelperImpl extends LoadBalancer.Helper { - private final List subchannels; - - public MockHelperImpl(List subchannels) { - this.subchannels = new ArrayList(subchannels); - } - + private class BaseHelper extends LoadBalancer.Helper { @Override public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { return null; @@ -2783,6 +3091,14 @@ public ScheduledExecutorService getScheduledExecutorService() { public void refreshNameResolution() { // noop } + } + + private class MockHelperImpl extends BaseHelper { + private final List subchannels; + + public MockHelperImpl(List subchannels) { + this.subchannels = new ArrayList(subchannels); + } @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { @@ -2799,4 +3115,23 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { throw new IllegalArgumentException("Unexpected addresses: " + args.getAddresses()); } } + + class RecordingHelper extends BaseHelper { + ConnectivityState state; + 
SubchannelPicker picker; + final Queue subchannels = new ArrayDeque<>(); + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + this.state = newState; + this.picker = newPicker; + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + FakeSubchannel subchannel = new FakeSubchannel(args.getAddresses(), args.getAttributes()); + subchannels.add(subchannel); + return subchannel; + } + } } diff --git a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java index 3e0258f2e40..1e130423a45 100644 --- a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java @@ -21,6 +21,7 @@ import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.InternalEquivalentAddressGroup.ATTR_WEIGHT; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; @@ -49,12 +50,18 @@ import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.ManagedChannel; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig; import java.net.SocketAddress; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Queue; +import java.util.Random; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -103,8 +110,12 @@ public void uncaughtException(Thread t, Throwable e) { @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown(). 
private PickSubchannelArgs mockArgs; + private boolean originalWeightedShuffling; + @Before public void setUp() { + originalWeightedShuffling = PickFirstLeafLoadBalancer.weightedShuffling; + for (int i = 0; i < 3; i++) { SocketAddress addr = new FakeSocketAddress("server" + i); servers.add(new EquivalentAddressGroup(addr)); @@ -120,6 +131,7 @@ public void setUp() { @After public void tearDown() throws Exception { + PickFirstLeafLoadBalancer.weightedShuffling = originalWeightedShuffling; verifyNoMoreInteractions(mockArgs); } @@ -141,6 +153,12 @@ public void pickAfterResolved() throws Exception { verifyNoMoreInteractions(mockHelper); } + @Test + public void pickAfterResolved_shuffle_oppositeWeightedShuffling() throws Exception { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffle(); + } + @Test public void pickAfterResolved_shuffle() throws Exception { loadBalancer.acceptResolvedAddresses( @@ -184,6 +202,103 @@ public void pickAfterResolved_noShuffle() throws Exception { verifyNoMoreInteractions(mockHelper); } + @Test + public void pickAfterResolved_shuffleImplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleImplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleImplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup(new FakeSocketAddress("server1")); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup(new FakeSocketAddress("server2")); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup(new FakeSocketAddress("server3")); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform_oppositeWeightedShuffling() { + 
PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleExplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_noWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = false; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_weightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = true; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + 
EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(75); // 100*12/16 + assertThat(counts[1]).isWithin(7).of(19); // 100*3/16 + assertThat(counts[2]).isWithin(7).of(6); // 100*1/16 + } + + /** Returns int[index_of_eag] array with number of times each eag was selected. */ + private int[] countAddressSelections(int trials, List eags) { + int[] counts = new int[eags.size()]; + Random random = new Random(1); + for (int i = 0; i < trials; i++) { + RecordingHelper helper = new RecordingHelper(); + PickFirstLoadBalancer lb = new PickFirstLoadBalancer(helper); + assertThat(lb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(eags) + .setAttributes(affinity) + .setLoadBalancingPolicyConfig( + new PickFirstLoadBalancerConfig(true, random.nextLong())) + .build())) + .isSameInstanceAs(Status.OK); + helper.subchannels.remove().listener.onSubchannelState( + ConnectivityStateInfo.forNonError(READY)); + + assertThat(helper.state).isEqualTo(READY); + Subchannel subchannel = helper.picker.pickSubchannel(mockArgs).getSubchannel(); + counts[eags.indexOf(subchannel.getAllAddresses().get(0))]++; + + lb.shutdown(); + } + return counts; + } + @Test public void requestConnectionPicker() throws Exception { loadBalancer.acceptResolvedAddresses( @@ -219,7 +334,7 @@ public void refreshNameResolutionAfterSubchannelConnectionBroken() { inOrder.verify(mockSubchannel).start(stateListenerCaptor.capture()); SubchannelStateListener stateListener = stateListenerCaptor.getValue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - assertSame(mockSubchannel, 
pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs).hasResult()).isFalse(); inOrder.verify(mockSubchannel).requestConnection(); stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); @@ -278,7 +393,7 @@ public void pickAfterResolvedAndChanged() throws Exception { assertThat(args.getAddresses()).isEqualTo(servers); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verify(mockSubchannel).requestConnection(); - assertEquals(mockSubchannel, pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs).hasResult()).isFalse(); loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); @@ -300,7 +415,7 @@ public void pickAfterStateChangeAfterResolution() throws Exception { verify(mockSubchannel).start(stateListenerCaptor.capture()); SubchannelStateListener stateListener = stateListenerCaptor.getValue(); verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - Subchannel subchannel = pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel(); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs).hasResult()).isFalse(); reset(mockHelper); when(mockHelper.getSynchronizationContext()).thenReturn(syncContext); @@ -317,7 +432,7 @@ public void pickAfterStateChangeAfterResolution() throws Exception { stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); - assertEquals(subchannel, pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); + assertEquals(mockSubchannel, pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); verify(mockHelper, atLeast(0)).getSynchronizationContext(); // Don't care verifyNoMoreInteractions(mockHelper); @@ -405,8 
+520,7 @@ public void nameResolutionSuccessAfterError() throws Exception { inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verify(mockSubchannel).requestConnection(); - assertEquals(mockSubchannel, pickerCaptor.getValue().pickSubchannel(mockArgs) - .getSubchannel()); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs).hasResult()).isFalse(); assertEquals(pickerCaptor.getValue().pickSubchannel(mockArgs), pickerCaptor.getValue().pickSubchannel(mockArgs)); @@ -487,4 +601,96 @@ public String toString() { return "FakeSocketAddress-" + name; } } + + private static class FakeSubchannel extends Subchannel { + private final Attributes attributes; + private List eags; + private SubchannelStateListener listener; + + public FakeSubchannel(List eags, Attributes attributes) { + this.eags = Collections.unmodifiableList(eags); + this.attributes = attributes; + } + + @Override + public List getAllAddresses() { + return eags; + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + @Override + public void start(SubchannelStateListener listener) { + this.listener = listener; + } + + @Override + public void updateAddresses(List addrs) { + this.eags = Collections.unmodifiableList(addrs); + } + + @Override + public void shutdown() { + listener.onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.SHUTDOWN)); + } + + @Override + public void requestConnection() { + } + + @Override + public String toString() { + return "FakeSubchannel@" + hashCode() + "(" + eags + ")"; + } + } + + private class BaseHelper extends Helper { + @Override + public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { + return null; + } + + @Override + public String getAuthority() { + return null; + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + // ignore + } + + @Override + public SynchronizationContext 
getSynchronizationContext() { + return syncContext; + } + + @Override + public void refreshNameResolution() { + // noop + } + } + + class RecordingHelper extends BaseHelper { + ConnectivityState state; + SubchannelPicker picker; + final Queue subchannels = new ArrayDeque<>(); + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + this.state = newState; + this.picker = newPicker; + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + FakeSubchannel subchannel = new FakeSubchannel(args.getAddresses(), args.getAttributes()); + subchannels.add(subchannel); + return subchannel; + } + } + } diff --git a/core/src/test/java/io/grpc/internal/ProxyDetectorImplTest.java b/core/src/test/java/io/grpc/internal/ProxyDetectorImplTest.java index 0432a474ac5..771050f119d 100644 --- a/core/src/test/java/io/grpc/internal/ProxyDetectorImplTest.java +++ b/core/src/test/java/io/grpc/internal/ProxyDetectorImplTest.java @@ -73,7 +73,7 @@ public ProxySelector get() { return proxySelector; } }; - proxyDetector = new ProxyDetectorImpl(proxySelectorSupplier, authenticator, null); + proxyDetector = new ProxyDetectorImpl(proxySelectorSupplier, authenticator); unresolvedProxy = InetSocketAddress.createUnresolved("10.0.0.1", proxyPort); proxySocketAddress = HttpConnectProxiedSocketAddress.newBuilder() .setTargetAddress(destination) @@ -82,45 +82,6 @@ public ProxySelector get() { .build(); } - @Test - public void override_hostPort() throws Exception { - final String overrideHost = "10.99.99.99"; - final int overridePort = 1234; - final String overrideHostWithPort = overrideHost + ":" + overridePort; - ProxyDetectorImpl proxyDetector = new ProxyDetectorImpl( - proxySelectorSupplier, - authenticator, - overrideHostWithPort); - ProxiedSocketAddress detected = proxyDetector.proxyFor(destination); - assertNotNull(detected); - assertEquals( - HttpConnectProxiedSocketAddress.newBuilder() - .setTargetAddress(destination) - 
.setProxyAddress( - new InetSocketAddress(InetAddress.getByName(overrideHost), overridePort)) - .build(), - detected); - } - - @Test - public void override_hostOnly() throws Exception { - final String overrideHostWithoutPort = "10.99.99.99"; - final int defaultPort = 80; - ProxyDetectorImpl proxyDetector = new ProxyDetectorImpl( - proxySelectorSupplier, - authenticator, - overrideHostWithoutPort); - ProxiedSocketAddress detected = proxyDetector.proxyFor(destination); - assertNotNull(detected); - assertEquals( - HttpConnectProxiedSocketAddress.newBuilder() - .setTargetAddress(destination) - .setProxyAddress( - new InetSocketAddress(InetAddress.getByName(overrideHostWithoutPort), defaultPort)) - .build(), - detected); - } - @Test public void returnNullWhenNoProxy() throws Exception { when(proxySelector.select(any(URI.class))) @@ -227,8 +188,7 @@ public ProxySelector get() { return null; } }, - authenticator, - null); + authenticator); assertNull(proxyDetector.proxyFor(destination)); } } diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 21ec46668fc..afbdaa395b0 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -147,6 +147,17 @@ public double nextDouble() { private final ChannelBufferMeter channelBufferUsed = new ChannelBufferMeter(); private final FakeClock fakeClock = new FakeClock(); + private static long calculateBackoffWithRetries(int retryCount) { + // Calculate the exponential backoff delay with jitter + double exponent = retryCount > 0 ? 
Math.pow(BACKOFF_MULTIPLIER, retryCount) : 1; + long delay = (long) (INITIAL_BACKOFF_IN_SECONDS * exponent); + return RetriableStream.intervalWithJitter(delay); + } + + private static long calculateMaxBackoff() { + return RetriableStream.intervalWithJitter(MAX_BACKOFF_IN_SECONDS); + } + private final class RecordedRetriableStream extends RetriableStream { RecordedRetriableStream(MethodDescriptor method, Metadata headers, ChannelBufferMeter channelBufferUsed, long perRpcBufferLimit, long channelBufferLimit, @@ -175,7 +186,8 @@ ClientStream newSubstream( Metadata metadata, ClientStreamTracer.Factory tracerFactory, int previousAttempts, - boolean isTransparentRetry) { + boolean isTransparentRetry, + boolean isHedgedStream) { bufferSizeTracer = tracerFactory.newClientStreamTracer(STREAM_INFO, metadata); int actualPreviousRpcAttemptsInHeader = metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS) == null @@ -307,7 +319,7 @@ public Void answer(InvocationOnMock in) { retriableStream.sendMessage("msg1 during backoff1"); retriableStream.sendMessage("msg2 during backoff1"); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM) - 1L, TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0) - 1L, TimeUnit.SECONDS); inOrder.verifyNoMoreInteractions(); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); @@ -364,9 +376,7 @@ public Void answer(InvocationOnMock in) { retriableStream.sendMessage("msg2 during backoff2"); retriableStream.sendMessage("msg3 during backoff2"); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * FAKE_RANDOM) - 1L, - TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(1) - 1L, TimeUnit.SECONDS); inOrder.verifyNoMoreInteractions(); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); @@ -459,7 +469,7 @@ public void retry_headersRead_cancel() { sublistenerCaptor1.getValue().closed( 
Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -518,7 +528,7 @@ public void retry_headersRead_closed() { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -584,7 +594,7 @@ public void retry_cancel_closed() { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -687,7 +697,7 @@ public void retry_unretriableClosed_cancel() { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -821,7 +831,7 @@ public boolean isReady() { // send more requests during backoff retriableStream.request(789); - 
fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); inOrder.verify(mockStream2).start(sublistenerCaptor2.get()); inOrder.verify(mockStream2).request(3); @@ -875,7 +885,7 @@ public void request(int numMessages) { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); inOrder.verify(mockStream2).request(3); @@ -920,7 +930,7 @@ public void start(ClientStreamListener listener) { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); inOrder.verify(retriableStreamRecorder).postCommit(); @@ -1028,7 +1038,7 @@ public boolean isReady() { retriableStream.request(789); readiness.add(retriableStream.isReady()); // expected false b/c in backoff - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); verify(mockStream2).start(any(ClientStreamListener.class)); readiness.add(retriableStream.isReady()); // expected true @@ -1110,7 +1120,7 @@ public void addPrevRetryAttemptsToRespHeaders() { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new 
Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -1160,13 +1170,12 @@ public void start(ClientStreamListener listener) { listener1.closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); // send requests during backoff retriableStream.request(3); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(1), TimeUnit.SECONDS); retriableStream.request(1); verify(mockStream1, never()).request(anyInt()); @@ -1207,7 +1216,7 @@ public void start(ClientStreamListener listener) { // retry listener1.closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); verify(mockStream2).start(any(ClientStreamListener.class)); verify(retriableStreamRecorder).postCommit(); @@ -1260,7 +1269,7 @@ public void perRpcBufferLimitExceededDuringBackoff() { bufferSizeTracer.outboundWireSize(2); verify(retriableStreamRecorder, never()).postCommit(); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); verify(mockStream2).start(any(ClientStreamListener.class)); verify(mockStream2).isReady(); @@ -1332,7 +1341,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { 
sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM) - 1L, TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0) - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1347,9 +1356,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { sublistenerCaptor2.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_2), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * FAKE_RANDOM) - 1L, - TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(1) - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1364,10 +1371,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { sublistenerCaptor3.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * BACKOFF_MULTIPLIER * FAKE_RANDOM) - - 1L, - TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(2) - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1382,7 +1386,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { sublistenerCaptor4.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_2), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (MAX_BACKOFF_IN_SECONDS * FAKE_RANDOM) - 1L, TimeUnit.SECONDS); + 
fakeClock.forwardTime(calculateMaxBackoff() - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1397,7 +1401,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { sublistenerCaptor5.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_2), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (MAX_BACKOFF_IN_SECONDS * FAKE_RANDOM) - 1L, TimeUnit.SECONDS); + fakeClock.forwardTime(calculateMaxBackoff() - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1480,7 +1484,7 @@ public void pushback() { sublistenerCaptor3.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM) - 1L, TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0) - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1495,9 +1499,7 @@ public void pushback() { sublistenerCaptor4.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_2), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * FAKE_RANDOM) - 1L, - TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(1) - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1512,10 +1514,7 @@ public void pushback() { sublistenerCaptor5.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_2), PROCESSED, new Metadata()); assertEquals(1, 
fakeClock.numPendingTasks()); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * BACKOFF_MULTIPLIER * FAKE_RANDOM) - - 1L, - TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(2) - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1804,7 +1803,7 @@ public void transparentRetry_onlyOnceOnRefused() { .closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), REFUSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); inOrder.verify(retriableStreamRecorder).newSubstream(1); ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -1907,7 +1906,7 @@ public void normalRetry_thenNoTransparentRetry_butNormalRetry() { .closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); inOrder.verify(retriableStreamRecorder).newSubstream(1); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -1923,8 +1922,7 @@ public void normalRetry_thenNoTransparentRetry_butNormalRetry() { .closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), REFUSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(1), TimeUnit.SECONDS); inOrder.verify(retriableStreamRecorder).newSubstream(2); ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -1960,7 +1958,7 
@@ public void normalRetry_thenNoTransparentRetry_andNoMoreRetry() { .closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); inOrder.verify(retriableStreamRecorder).newSubstream(1); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -2592,9 +2590,7 @@ public void run() { .closed(Status.fromCode(NON_FATAL_STATUS_CODE_1), REFUSED, new Metadata()); } finally { transport2Lock.unlock(); - if (transport1Lock.tryLock()) { - transport1Lock.unlock(); - } + transport1Lock.unlock(); } } }, "Thread-transport2"); diff --git a/core/src/test/java/io/grpc/internal/RetryingNameResolverTest.java b/core/src/test/java/io/grpc/internal/RetryingNameResolverTest.java index 6347416f0ca..1da93f05fe2 100644 --- a/core/src/test/java/io/grpc/internal/RetryingNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/RetryingNameResolverTest.java @@ -17,7 +17,6 @@ package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -28,7 +27,6 @@ import io.grpc.NameResolver.ResolutionResult; import io.grpc.Status; import io.grpc.SynchronizationContext; -import io.grpc.internal.RetryingNameResolver.ResolutionResultListener; import java.lang.Thread.UncaughtExceptionHandler; import org.junit.Before; import org.junit.Rule; @@ -58,8 +56,6 @@ public class RetryingNameResolverTest { private RetryScheduler mockRetryScheduler; @Captor private ArgumentCaptor listenerCaptor; - @Captor - private ArgumentCaptor onResultCaptor; private final SynchronizationContext syncContext = new SynchronizationContext( mock(UncaughtExceptionHandler.class)); @@ 
-77,21 +73,14 @@ public void startAndShutdown() { retryingNameResolver.shutdown(); } - // Make sure the ResolutionResultListener callback is added to the ResolutionResult attributes, - // and the retry scheduler is reset since the name resolution was successful. @Test public void onResult_success() { + when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.OK); retryingNameResolver.start(mockListener); verify(mockNameResolver).start(listenerCaptor.capture()); listenerCaptor.getValue().onResult(ResolutionResult.newBuilder().build()); - verify(mockListener).onResult(onResultCaptor.capture()); - ResolutionResultListener resolutionResultListener = onResultCaptor.getValue() - .getAttributes() - .get(RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY); - assertThat(resolutionResultListener).isNotNull(); - resolutionResultListener.resolutionAttempted(Status.OK); verify(mockRetryScheduler).reset(); } @@ -107,21 +96,15 @@ public void onResult2_sucesss() { verify(mockRetryScheduler).reset(); } - // Make sure the ResolutionResultListener callback is added to the ResolutionResult attributes, - // and that a retry gets scheduled when the resolution results are rejected. + // Make sure that a retry gets scheduled when the resolution results are rejected. 
@Test public void onResult_failure() { + when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.UNAVAILABLE); retryingNameResolver.start(mockListener); verify(mockNameResolver).start(listenerCaptor.capture()); listenerCaptor.getValue().onResult(ResolutionResult.newBuilder().build()); - verify(mockListener).onResult(onResultCaptor.capture()); - ResolutionResultListener resolutionResultListener = onResultCaptor.getValue() - .getAttributes() - .get(RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY); - assertThat(resolutionResultListener).isNotNull(); - resolutionResultListener.resolutionAttempted(Status.UNAVAILABLE); verify(mockRetryScheduler).schedule(isA(Runnable.class)); } @@ -138,24 +121,6 @@ public void onResult2_failure() { verify(mockRetryScheduler).schedule(isA(Runnable.class)); } - // Wrapping a NameResolver more than once is a misconfiguration. - @Test - public void onResult_failure_doubleWrapped() { - NameResolver doubleWrappedResolver = new RetryingNameResolver(retryingNameResolver, - mockRetryScheduler, syncContext); - - doubleWrappedResolver.start(mockListener); - verify(mockNameResolver).start(listenerCaptor.capture()); - - try { - listenerCaptor.getValue().onResult(ResolutionResult.newBuilder().build()); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("can only be used once"); - return; - } - fail("An exception should have been thrown for a double wrapped NAmeResolver"); - } - // A retry should get scheduled when name resolution fails. 
@Test public void onError() { @@ -165,4 +130,4 @@ public void onError() { verify(mockListener).onError(Status.DEADLINE_EXCEEDED); verify(mockRetryScheduler).schedule(isA(Runnable.class)); } -} \ No newline at end of file +} diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java index 652c94a4640..7394c83eab2 100644 --- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java @@ -16,12 +16,14 @@ package io.grpc.internal; +import static com.google.common.truth.Truth.assertThat; import static io.grpc.internal.GrpcUtil.CONTENT_LENGTH_KEY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; @@ -54,7 +56,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -64,8 +65,6 @@ @RunWith(JUnit4.class) public class ServerCallImplTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Mock private ServerStream stream; @@ -175,20 +174,20 @@ public void sendHeader_contentLengthDiscarded() { @Test public void sendHeader_failsOnSecondCall() { call.sendHeaders(new Metadata()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("sendHeaders has already been called"); - - call.sendHeaders(new Metadata()); + Metadata headers = new 
Metadata(); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> call.sendHeaders(headers)); + assertThat(e).hasMessageThat().isEqualTo("sendHeaders has already been called"); } @Test public void sendHeader_failsOnClosed() { call.close(Status.CANCELLED, new Metadata()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("call is closed"); - - call.sendHeaders(new Metadata()); + Metadata headers = new Metadata(); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> call.sendHeaders(headers)); + assertThat(e).hasMessageThat().isEqualTo("call is closed"); } @Test @@ -204,18 +203,16 @@ public void sendMessage_failsOnClosed() { call.sendHeaders(new Metadata()); call.close(Status.CANCELLED, new Metadata()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("call is closed"); - - call.sendMessage(1234L); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> call.sendMessage(1234L)); + assertThat(e).hasMessageThat().isEqualTo("call is closed"); } @Test public void sendMessage_failsIfheadersUnsent() { - thrown.expect(IllegalStateException.class); - thrown.expectMessage("sendHeaders has not been called"); - - call.sendMessage(1234L); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> call.sendMessage(1234L)); + assertThat(e).hasMessageThat().isEqualTo("sendHeaders has not been called"); } @Test @@ -490,9 +487,10 @@ public void streamListener_unexpectedRuntimeException() { InputStream inputStream = UNARY_METHOD.streamRequest(1234L); - thrown.expect(RuntimeException.class); - thrown.expectMessage("unexpected exception"); - streamListener.messagesAvailable(new SingleMessageProducer(inputStream)); + SingleMessageProducer producer = new SingleMessageProducer(inputStream); + RuntimeException e = assertThrows(RuntimeException.class, + () -> streamListener.messagesAvailable(producer)); + assertThat(e).hasMessageThat().isEqualTo("unexpected 
exception"); } private static class LongMarshaller implements Marshaller { diff --git a/core/src/test/java/io/grpc/internal/ServerImplBuilderTest.java b/core/src/test/java/io/grpc/internal/ServerImplBuilderTest.java index 107591038d6..c2cb281a19e 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplBuilderTest.java @@ -18,11 +18,13 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; import io.grpc.InternalConfigurator; import io.grpc.InternalConfiguratorRegistry; import io.grpc.Metadata; +import io.grpc.MetricRecorder; +import io.grpc.MetricSink; +import io.grpc.NoopMetricSink; import io.grpc.ServerBuilder; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; @@ -74,7 +76,8 @@ public void setUp() throws Exception { new ClientTransportServersBuilder() { @Override public InternalServer buildClientTransportServers( - List streamTracerFactories) { + List streamTracerFactories, + MetricRecorder metricRecorder) { throw new UnsupportedOperationException(); } }); @@ -129,6 +132,13 @@ public void getTracerFactories_disableBoth() { assertThat(factories).containsExactly(DUMMY_USER_TRACER); } + @Test + public void addMetricSink_addsToSinks() { + MetricSink noopMetricSink = new NoopMetricSink(); + builder.addMetricSink(noopMetricSink); + assertThat(builder.metricSinks).containsExactly(noopMetricSink); + } + @Test public void getTracerFactories_callsGet() throws Exception { Class runnable = classLoader.loadClass(StaticTestingClassLoaderCallsGet.class.getName()); @@ -140,17 +150,14 @@ public static final class StaticTestingClassLoaderCallsGet implements Runnable { public void run() { ServerImplBuilder builder = new ServerImplBuilder( - streamTracerFactories -> { + (streamTracerFactories, metricRecorder) -> { throw new UnsupportedOperationException(); }); assertThat(builder.getTracerFactories()).hasSize(2); 
assertThat(builder.interceptors).hasSize(0); - try { - InternalConfiguratorRegistry.setConfigurators(Collections.emptyList()); - fail("exception expected"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("Configurators are already set"); - } + InternalConfiguratorRegistry.setConfigurators(Collections.emptyList()); + assertThat(InternalConfiguratorRegistry.getConfigurators()).isEmpty(); + assertThat(InternalConfiguratorRegistry.getConfiguratorsCallCountBeforeSet()).isEqualTo(1); } } @@ -173,7 +180,7 @@ public void configureServerBuilder(ServerBuilder builder) { })); ServerImplBuilder builder = new ServerImplBuilder( - streamTracerFactories -> { + (streamTracerFactories, metricRecorder) -> { throw new UnsupportedOperationException(); }); assertThat(builder.getTracerFactories()).containsExactly(DUMMY_USER_TRACER); @@ -196,7 +203,7 @@ public void run() { InternalConfiguratorRegistry.setConfigurators(Collections.emptyList()); ServerImplBuilder builder = new ServerImplBuilder( - streamTracerFactories -> { + (streamTracerFactories, metricRecorder) -> { throw new UnsupportedOperationException(); }); assertThat(builder.getTracerFactories()).isEmpty(); diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index 3125edca1e6..3405cb9bb0c 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -26,6 +26,7 @@ import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; @@ -52,6 +53,7 @@ import io.grpc.Channel; import io.grpc.Compressor; import io.grpc.Context; +import io.grpc.Deadline; import io.grpc.Grpc; import io.grpc.HandlerRegistry; import 
io.grpc.IntegerMarshaller; @@ -63,6 +65,7 @@ import io.grpc.InternalServerInterceptors; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MetricRecorder; import io.grpc.ServerCall; import io.grpc.ServerCall.Listener; import io.grpc.ServerCallExecutorSupplier; @@ -104,7 +107,6 @@ import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -140,8 +142,6 @@ public boolean shouldAccept(Runnable runnable) { }; private static final String AUTHORITY = "some_authority"; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @BeforeClass @@ -207,7 +207,8 @@ public void startUp() throws IOException { new ClientTransportServersBuilder() { @Override public InternalServer buildClientTransportServers( - List streamTracerFactories) { + List streamTracerFactories, + MetricRecorder metricRecorder) { throw new UnsupportedOperationException(); } }); @@ -1148,11 +1149,21 @@ public ServerCall.Listener startCall( @Test public void testContextExpiredBeforeStreamCreate_StreamCancelNotCalledBeforeSetListener() throws Exception { + builder.ticker = new Deadline.Ticker() { + private long time; + + @Override + public long nanoTime() { + time += 1000; + return time; + } + }; + AtomicBoolean contextCancelled = new AtomicBoolean(false); AtomicReference context = new AtomicReference<>(); AtomicReference> callReference = new AtomicReference<>(); - testStreamClose_setup(callReference, context, contextCancelled, 0L); + testStreamClose_setup(callReference, context, contextCancelled, 1L); // This assert that stream.setListener(jumpListener) is called before stream.cancel(), which // prevents extremely short deadlines causing NPEs. 
@@ -1228,7 +1239,7 @@ public void testStreamClose_deadlineExceededTriggersImmediateCancellation() thro assertFalse(context.get().isCancelled()); assertEquals(1, timer.forwardNanos(1)); - + assertTrue(callReference.get().isCancelled()); assertTrue(context.get().isCancelled()); assertThat(context.get().cancellationCause()).isNotNull(); @@ -1260,9 +1271,8 @@ public List getListenSocketAddresses() { public void getPortBeforeStartedFails() { transportServer = new SimpleServer(); createServer(); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("started"); - server.getPort(); + IllegalStateException e = assertThrows(IllegalStateException.class, () -> server.getPort()); + assertThat(e).hasMessageThat().isEqualTo("Not started"); } @Test @@ -1271,9 +1281,8 @@ public void getPortAfterTerminationFails() throws Exception { createAndStartServer(); server.shutdown(); server.awaitTermination(); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("terminated"); - server.getPort(); + IllegalStateException e = assertThrows(IllegalStateException.class, () -> server.getPort()); + assertThat(e).hasMessageThat().isEqualTo("Already terminated"); } @Test diff --git a/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java b/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java index 6f255763d30..0daee676b82 100644 --- a/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java +++ b/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; +import static io.grpc.internal.UriWrapper.wrap; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; @@ -168,7 +169,7 @@ private void createChannel(ClientInterceptor... 
interceptors) { new ManagedChannelImpl( channelBuilder, mockTransportFactory, - expectedUri, + wrap(expectedUri), nameResolverProvider, new FakeBackoffPolicyProvider(), balancerRpcExecutorPool, diff --git a/core/src/test/java/io/grpc/internal/SharedResourceHolderTest.java b/core/src/test/java/io/grpc/internal/SharedResourceHolderTest.java index d27195e2490..692b22a0a68 100644 --- a/core/src/test/java/io/grpc/internal/SharedResourceHolderTest.java +++ b/core/src/test/java/io/grpc/internal/SharedResourceHolderTest.java @@ -30,7 +30,9 @@ import io.grpc.internal.SharedResourceHolder.Resource; import java.util.LinkedList; +import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Delayed; +import java.util.concurrent.FutureTask; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -201,6 +203,46 @@ public void close(ResourceInstance instance) { assertNotSame(instance, holder.getInternal(resource)); } + @Test(timeout = 5000) + public void closeRunsConcurrently() throws Exception { + CyclicBarrier barrier = new CyclicBarrier(2); + class SlowResource implements Resource { + @Override + public ResourceInstance create() { + return new ResourceInstance(); + } + + @Override + public void close(ResourceInstance instance) { + instance.closed = true; + try { + barrier.await(); + barrier.await(); + } catch (Exception ex) { + throw new AssertionError(ex); + } + } + } + + Resource resource = new SlowResource(); + ResourceInstance instance = holder.getInternal(resource); + holder.releaseInternal(resource, instance); + MockScheduledFuture scheduledDestroyTask = scheduledDestroyTasks.poll(); + FutureTask runTask = new FutureTask<>(scheduledDestroyTask::runTask, null); + Thread t = new Thread(runTask); + t.start(); + + barrier.await(); // Ensure the other thread has blocked + assertTrue(instance.closed); + instance = holder.getInternal(resource); + assertFalse(instance.closed); + 
holder.releaseInternal(resource, instance); + + barrier.await(); // Resume the other thread + t.join(); + runTask.get(); // Check for exception + } + private class MockExecutorFactory implements SharedResourceHolder.ScheduledExecutorFactory { @Override diff --git a/core/src/test/java/io/grpc/internal/SpiffeUtilTest.java b/core/src/test/java/io/grpc/internal/SpiffeUtilTest.java new file mode 100644 index 00000000000..57824cf207f --- /dev/null +++ b/core/src/test/java/io/grpc/internal/SpiffeUtilTest.java @@ -0,0 +1,388 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.internal; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import com.google.common.base.Optional; +import com.google.common.io.ByteStreams; +import io.grpc.internal.SpiffeUtil.SpiffeBundle; +import io.grpc.internal.SpiffeUtil.SpiffeId; +import io.grpc.testing.TlsTesting; +import io.grpc.util.CertificateUtils; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.InputStream; +import java.io.OutputStream; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collection; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + + +@RunWith(Enclosed.class) +public class SpiffeUtilTest { + + @RunWith(Parameterized.class) + public static class ParseSuccessTest { + @Parameter + public String uri; + + @Parameter(1) + public String trustDomain; + + @Parameter(2) + public String path; + + @Test + public void parseSuccessTest() { + SpiffeUtil.SpiffeId spiffeId = SpiffeUtil.parse(uri); + assertEquals(trustDomain, spiffeId.getTrustDomain()); + assertEquals(path, spiffeId.getPath()); + } + + @Parameters(name = "spiffeId={0}") + public static Collection data() { + return Arrays.asList(new String[][] { + {"spiffe://example.com", "example.com", ""}, + {"spiffe://example.com/us", "example.com", "/us"}, + {"spIFfe://qa-staging.final_check.example.com/us", "qa-staging.final_check.example.com", + "/us"}, + {"spiffe://example.com/country/us/state/FL/city/Miami", "example.com", + "/country/us/state/FL/city/Miami"}, + 
{"SPIFFE://example.com/Czech.Republic/region0.1/city_of-Prague", "example.com", + "/Czech.Republic/region0.1/city_of-Prague"}, + {"spiffe://trust-domain-name/path", "trust-domain-name", "/path"}, + {"spiffe://staging.example.com/payments/mysql", "staging.example.com", "/payments/mysql"}, + {"spiffe://staging.example.com/payments/web-fe", "staging.example.com", + "/payments/web-fe"}, + {"spiffe://k8s-west.example.com/ns/staging/sa/default", "k8s-west.example.com", + "/ns/staging/sa/default"}, + {"spiffe://example.com/9eebccd2-12bf-40a6-b262-65fe0487d453", "example.com", + "/9eebccd2-12bf-40a6-b262-65fe0487d453"}, + {"spiffe://trustdomain/.a..", "trustdomain", "/.a.."}, + {"spiffe://trustdomain/...", "trustdomain", "/..."}, + {"spiffe://trustdomain/abcdefghijklmnopqrstuvwxyz", "trustdomain", + "/abcdefghijklmnopqrstuvwxyz"}, + {"spiffe://trustdomain/abc0123.-_", "trustdomain", "/abc0123.-_"}, + {"spiffe://trustdomain/0123456789", "trustdomain", "/0123456789"}, + {"spiffe://trustdomain0123456789/path", "trustdomain0123456789", "/path"}, + }); + } + } + + @RunWith(Parameterized.class) + public static class ParseFailureTest { + @Parameter + public String uri; + + @Test + public void parseFailureTest() { + assertThrows(IllegalArgumentException.class, () -> SpiffeUtil.parse(uri)); + } + + @Parameters(name = "spiffeId={0}") + public static Collection data() { + return Arrays.asList( + "spiffe:///", + "spiffe://example!com", + "spiffe://exampleя.com/workload-1", + "spiffe://example.com/us/florida/miamiя", + "spiffe:/trustdomain/path", + "spiffe:///path", + "spiffe://trust%20domain/path", + "spiffe://user@trustdomain/path", + "spiffe:// /", + "", + "http://trustdomain/path", + "//trustdomain/path", + "://trustdomain/path", + "piffe://trustdomain/path", + "://", + "://trustdomain", + "spiff", + "spiffe", + "spiffe:////", + "spiffe://trust.domain/../path" + ); + } + } + + public static class ExceptionMessageTest { + + @Test + public void spiffeUriFormatTest() { + 
NullPointerException npe = assertThrows(NullPointerException.class, () -> + SpiffeUtil.parse(null)); + assertEquals("uri", npe.getMessage()); + + IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("https://example.com")); + assertEquals("Spiffe Id must start with spiffe://", iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com/workload#1")); + assertEquals("Spiffe Id must not contain query fragments", iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com/workload-1?t=1")); + assertEquals("Spiffe Id must not contain query parameters", iae.getMessage()); + } + + @Test + public void spiffeTrustDomainFormatTest() { + IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://")); + assertEquals("Trust Domain can't be empty", iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://eXample.com")); + assertEquals( + "Trust Domain must contain only letters, numbers, dots, dashes, and underscores " + + "([a-z0-9.-_])", + iae.getMessage()); + + StringBuilder longTrustDomain = new StringBuilder("spiffe://pi.eu."); + for (int i = 0; i < 50; i++) { + longTrustDomain.append("pi.eu"); + } + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse(longTrustDomain.toString())); + assertEquals("Trust Domain maximum length is 255 characters", iae.getMessage()); + + @SuppressWarnings("OrphanedFormatString") + StringBuilder longSpiffe = new StringBuilder("spiffe://mydomain%21com/"); + for (int i = 0; i < 405; i++) { + longSpiffe.append("qwert"); + } + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse(longSpiffe.toString())); + assertEquals("Spiffe Id maximum length is 2048 characters", iae.getMessage()); + } + + @Test + public void spiffePathFormatTest() { 
+ IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com//")); + assertEquals("Path must not include a trailing '/'", iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com/")); + assertEquals("Path must not include a trailing '/'", iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com/us//miami")); + assertEquals("Individual path segments must not be empty", iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com/us/.")); + assertEquals("Individual path segments must not be relative path modifiers (i.e. ., ..)", + iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com/us!")); + assertEquals("Individual path segments must contain only letters, numbers, dots, dashes, and " + + "underscores ([a-zA-Z0-9.-_])", iae.getMessage()); + } + } + + public static class CertificateApiTest { + private static final String SPIFFE_PEM_FILE = "spiffe_cert.pem"; + private static final String MULTI_URI_SAN_PEM_FILE = "spiffe_multi_uri_san_cert.pem"; + private static final String SERVER_0_PEM_FILE = "server0.pem"; + private static final String TEST_DIRECTORY_PREFIX = "io/grpc/internal/"; + private static final String SPIFFE_TRUST_BUNDLE = "spiffebundle.json"; + private static final String SPIFFE_TRUST_BUNDLE_WITH_EC_KTY = "spiffebundle_ec.json"; + private static final String SPIFFE_TRUST_BUNDLE_MALFORMED = "spiffebundle_malformed.json"; + private static final String SPIFFE_TRUST_BUNDLE_CORRUPTED_CERT = + "spiffebundle_corrupted_cert.json"; + private static final String SPIFFE_TRUST_BUNDLE_WRONG_KTY = "spiffebundle_wrong_kty.json"; + private static final String SPIFFE_TRUST_BUNDLE_WRONG_KID = "spiffebundle_wrong_kid.json"; + private static final String 
SPIFFE_TRUST_BUNDLE_WRONG_USE = "spiffebundle_wrong_use.json"; + private static final String SPIFFE_TRUST_BUNDLE_WRONG_MULTI_CERTS = + "spiffebundle_wrong_multi_certs.json"; + private static final String SPIFFE_TRUST_BUNDLE_DUPLICATES = "spiffebundle_duplicates.json"; + private static final String SPIFFE_TRUST_BUNDLE_WRONG_ROOT = "spiffebundle_wrong_root.json"; + private static final String SPIFFE_TRUST_BUNDLE_WRONG_SEQ = "spiffebundle_wrong_seq_type.json"; + private static final String DOMAIN_ERROR_MESSAGE = + " Certificate loading for trust domain 'google.com' failed."; + + + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); + + private X509Certificate[] spiffeCert; + private X509Certificate[] multipleUriSanCert; + private X509Certificate[] serverCert0; + + @Before + public void setUp() throws Exception { + spiffeCert = CertificateUtils.getX509Certificates(TlsTesting.loadCert(SPIFFE_PEM_FILE)); + multipleUriSanCert = CertificateUtils.getX509Certificates(TlsTesting + .loadCert(MULTI_URI_SAN_PEM_FILE)); + serverCert0 = CertificateUtils.getX509Certificates(TlsTesting.loadCert(SERVER_0_PEM_FILE)); + } + + private String copyFileToTmp(String fileName) throws Exception { + File tempFile = tempFolder.newFile(fileName); + try (InputStream resourceStream = SpiffeUtilTest.class.getClassLoader() + .getResourceAsStream(TEST_DIRECTORY_PREFIX + fileName); + OutputStream fileStream = new FileOutputStream(tempFile)) { + ByteStreams.copy(resourceStream, fileStream); + fileStream.flush(); + } + return tempFile.toString(); + } + + @Test + public void extractSpiffeIdSuccessTest() throws Exception { + Optional spiffeId = SpiffeUtil.extractSpiffeId(spiffeCert); + assertTrue(spiffeId.isPresent()); + assertEquals("foo.bar.com", spiffeId.get().getTrustDomain()); + assertEquals("/client/workload/1", spiffeId.get().getPath()); + } + + @Test + public void extractSpiffeIdFailureTest() throws Exception { + Optional spiffeId = SpiffeUtil.extractSpiffeId(serverCert0); + 
assertFalse(spiffeId.isPresent()); + IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .extractSpiffeId(multipleUriSanCert)); + assertEquals("Multiple URI SAN values found in the leaf cert.", iae.getMessage()); + + } + + @Test + public void extractSpiffeIdFromChainTest() throws Exception { + // Check that the SPIFFE ID is extracted only from the leaf cert in the chain (spiffeCert + // contains it, but serverCert0 does not). + X509Certificate[] leafWithSpiffeChain = new X509Certificate[]{spiffeCert[0], serverCert0[0]}; + assertTrue(SpiffeUtil.extractSpiffeId(leafWithSpiffeChain).isPresent()); + X509Certificate[] leafWithoutSpiffeChain = + new X509Certificate[]{serverCert0[0], spiffeCert[0]}; + assertFalse(SpiffeUtil.extractSpiffeId(leafWithoutSpiffeChain).isPresent()); + } + + @Test + public void extractSpiffeIdParameterValidityTest() { + NullPointerException npe = assertThrows(NullPointerException.class, () -> SpiffeUtil + .extractSpiffeId(null)); + assertEquals("certChain", npe.getMessage()); + IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .extractSpiffeId(new X509Certificate[]{})); + assertEquals("certChain can't be empty", iae.getMessage()); + } + + @Test + public void loadTrustBundleFromFileSuccessTest() throws Exception { + SpiffeBundle tb = SpiffeUtil.loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE)); + assertEquals(2, tb.getSequenceNumbers().size()); + assertEquals(12035488L, (long) tb.getSequenceNumbers().get("example.com")); + assertEquals(-1L, (long) tb.getSequenceNumbers().get("test.example.com")); + assertEquals(3, tb.getBundleMap().size()); + assertEquals(0, tb.getBundleMap().get("test.google.com.au").size()); + assertEquals(1, tb.getBundleMap().get("example.com").size()); + assertEquals(2, tb.getBundleMap().get("test.example.com").size()); + Optional spiffeId = SpiffeUtil.extractSpiffeId(tb.getBundleMap().get("example.com") + .toArray(new 
X509Certificate[0])); + assertTrue(spiffeId.isPresent()); + assertEquals("foo.bar.com", spiffeId.get().getTrustDomain()); + + SpiffeBundle tb_ec = SpiffeUtil.loadTrustBundleFromFile( + copyFileToTmp(SPIFFE_TRUST_BUNDLE_WITH_EC_KTY)); + assertEquals(2, tb_ec.getSequenceNumbers().size()); + assertEquals(12035488L, (long) tb_ec.getSequenceNumbers().get("example.com")); + assertEquals(-1L, (long) tb_ec.getSequenceNumbers().get("test.example.com")); + assertEquals(3, tb_ec.getBundleMap().size()); + assertEquals(0, tb_ec.getBundleMap().get("test.google.com.au").size()); + assertEquals(1, tb_ec.getBundleMap().get("example.com").size()); + assertEquals(2, tb_ec.getBundleMap().get("test.example.com").size()); + Optional spiffeId_ec = + SpiffeUtil.extractSpiffeId(tb_ec.getBundleMap().get("example.com") + .toArray(new X509Certificate[0])); + assertTrue(spiffeId_ec.isPresent()); + assertEquals("foo.bar.com", spiffeId_ec.get().getTrustDomain()); + } + + @Test + public void loadTrustBundleFromFileFailureTest() { + // Check the exception if JSON root element is different from 'trust_domains' + NullPointerException npe = assertThrows(NullPointerException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_WRONG_ROOT))); + assertEquals("Mandatory trust_domains element is missing", npe.getMessage()); + // Check the exception if JSON root element is different from 'trust_domains' + ClassCastException cce = assertThrows(ClassCastException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_WRONG_SEQ))); + assertTrue(cce.getMessage().contains("Number expected to be long")); + // Check the exception if JSON file doesn't contain an object + IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_MALFORMED))); + assertTrue(iae.getMessage().contains("SPIFFE Trust Bundle should be a JSON object.")); + // Check the exception if 
JSON contains duplicates + iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_DUPLICATES))); + assertEquals("Duplicate key found: google.com", iae.getMessage()); + // Check the exception if 'x5c' value cannot be parsed + iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_CORRUPTED_CERT))); + assertEquals("Certificate can't be parsed." + DOMAIN_ERROR_MESSAGE, iae.getMessage()); + // Check the exception if 'kty' value differs from 'RSA' + iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_WRONG_KTY))); + assertEquals( + "'kty' parameter must be one of [RSA, EC] but 'null' found." + DOMAIN_ERROR_MESSAGE, + iae.getMessage()); + // Check the exception if 'kid' has a value + iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_WRONG_KID))); + assertEquals("'kid' parameter must not be set." + DOMAIN_ERROR_MESSAGE, iae.getMessage()); + // Check the exception if 'use' value differs from 'x509-svid' + iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_WRONG_USE))); + assertEquals("'use' parameter must be 'x509-svid' but 'i_am_not_x509-svid' found." + + DOMAIN_ERROR_MESSAGE, iae.getMessage()); + // Check the exception if multiple certs are provided for 'x5c' + iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_WRONG_MULTI_CERTS))); + assertEquals("Exactly 1 certificate is expected, but 2 found." 
+ DOMAIN_ERROR_MESSAGE, + iae.getMessage()); + } + + @Test + public void loadTrustBundleFromFileParameterValidityTest() { + NullPointerException npe = assertThrows(NullPointerException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(null)); + assertEquals("trustBundleFile", npe.getMessage()); + FileNotFoundException nsfe = assertThrows(FileNotFoundException.class, () -> SpiffeUtil + .loadTrustBundleFromFile("i_do_not_exist")); + assertTrue( + "Did not contain expected substring: " + nsfe.getMessage(), + nsfe.getMessage().contains("i_do_not_exist")); + } + } +} diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle.json b/core/src/test/resources/io/grpc/internal/spiffebundle.json new file mode 100644 index 00000000000..f968f730d94 --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle.json @@ -0,0 +1,115 @@ +{ + "trust_domains": { + "test.google.com.au": {}, + "example.com": { + "spiffe_sequence": 12035488, + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL + BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL + BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy + NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM + MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu + dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ + LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z + G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO + a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z + JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV + m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 + 7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc + msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc + DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN + 
zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs + vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI + sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH + HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP + BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t + L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ + aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 + 4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 + IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ + PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV + +j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D + vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq + yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV + z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx + x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U + 0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX + GA91fn0891b5eEW8BJHXX0jri0aN8g=="], + "n": "", + "e": "AQAB" + } + ] + }, + "test.example.com": { + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL + BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL + BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy + NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM + MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu + dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ + LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z + G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO + a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z + JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV + m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 + 
7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc + msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc + DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN + zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs + vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI + sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH + HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP + BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t + L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ + aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 + 4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 + IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ + PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV + +j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D + vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq + yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV + z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx + x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U + 0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX + GA91fn0891b5eEW8BJHXX0jri0aN8g=="], + "n": "", + "e": "AQAB" + }, + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIELTCCAxWgAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShEwDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 + MDkxNzE2MTk0NFoXDTM0MDkxNTE2MTk0NFowTjELMAkGA1UEBhMCVVMxCzAJBgNV + BAgMAkNBMQwwCgYDVQQHDANTVkwxDTALBgNVBAoMBGdSUEMxFTATBgNVBAMMDHRl + c3QtY2xpZW50MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOcTjjcS + SfG/EGrr6G+f+3T2GXyHHfroQFi9mZUz80L7uKBdECOImID+YhoK8vcxLQjPmEEv + FIYgJT5amugDcYIgUhMjBx/8RPJaP/nGmBngAqsuuNCaZfyaHBRqN8XdS/AwmsI5 + Wo+nru0+0/7aQFdqqtd2+e9dHjUWwgHxXvMgC4hkHpsdCGIZWVzWyBliwTYQYb1Y + 
yYe1LzqqQA5OMbZfKOY9MYDCEYOliRiunOn30iIOHj9V5qLzWGfSyxCRuvLRdEP8 + iDeNweHbdaKuI80nQmxuBdRIspE9k5sD1WA4vLZpeg3zggxp4rfLL5zBJgb/33D3 + d9Rkm14xfDPihhkCAwEAAaOB+jCB9zBZBgNVHREEUjBQhiZzcGlmZmU6Ly9mb28u + YmFyLmNvbS9jbGllbnQvd29ya2xvYWQvMYYmc3BpZmZlOi8vZm9vLmJhci5jb20v + Y2xpZW50L3dvcmtsb2FkLzIwHQYDVR0OBBYEFG9GkBgdBg/p0U9/lXv8zIJ+2c2N + MHsGA1UdIwR0MHKhWqRYMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 + YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMM + BnRlc3RjYYIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQELBQADggEB + AJ4Cbxv+02SpUgkEu4hP/1+8DtSBXUxNxI0VG4e3Ap2+Rhjm3YiFeS/UeaZhNrrw + UEjkSTPFODyXR7wI7UO9OO1StyD6CMkp3SEvevU5JsZtGL6mTiTLTi3Qkywa91Bt + GlyZdVMghA1bBJLBMwiD5VT5noqoJBD7hDy6v9yNmt1Sw2iYBJPqI3Gnf5bMjR3s + UICaxmFyqaMCZsPkfJh0DmZpInGJys3m4QqGz6ZE2DWgcSr1r/ML7/5bSPjjr8j4 + WFFSqFR3dMu8CbGnfZTCTXa4GTX/rARXbAO67Z/oJbJBK7VKayskL+PzKuohb9ox + jGL772hQMbwtFCOFXu5VP0s="] + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_corrupted_cert.json b/core/src/test/resources/io/grpc/internal/spiffebundle_corrupted_cert.json new file mode 100644 index 00000000000..9ca51733ff3 --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_corrupted_cert.json @@ -0,0 +1,14 @@ +{ + "trust_domains": { + "google.com": { + "spiffe_sequence": 123, + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["UNPARSABLE_CERTIFICATE"] + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_duplicates.json b/core/src/test/resources/io/grpc/internal/spiffebundle_duplicates.json new file mode 100644 index 00000000000..3f015bd1568 --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_duplicates.json @@ -0,0 +1,23 @@ +{ + "trust_domains": { + "google.com": { + "spiffe_sequence": 123, + "keys": [ + { + "x5c": "VALUE_DOESN'T_MATTER" + } + ] + }, + "google.com": { + "spiffe_sequence": 123, + "keys": [ + { + "use": 
"x509-svid", + "kid": "some_value", + "x5c": "VALUE_DOESN'T_MATTER" + } + ] + }, + "test.google.com.au": {} + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_ec.json b/core/src/test/resources/io/grpc/internal/spiffebundle_ec.json new file mode 100644 index 00000000000..1732310f8cf --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_ec.json @@ -0,0 +1,116 @@ +{ + "trust_domains": { + "test.google.com.au": {}, + "example.com": { + "spiffe_sequence": 12035488, + "keys": [ + { + + "kty": "EC", + "use": "x509-svid", + "x5c": ["MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL + BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL + BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy + NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM + MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu + dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ + LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z + G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO + a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z + JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV + m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 + 7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc + msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc + DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN + zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs + vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI + sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH + HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP + BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t + L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ + 
aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 + 4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 + IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ + PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV + +j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D + vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq + yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV + z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx + x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U + 0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX + GA91fn0891b5eEW8BJHXX0jri0aN8g=="], + "n": "", + "e": "AQAB" + } + ] + }, + "test.example.com": { + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL + BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL + BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy + NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM + MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu + dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ + LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z + G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO + a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z + JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV + m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 + 7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc + msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc + DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN + zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs + vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI + sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH + 
HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP + BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t + L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ + aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 + 4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 + IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ + PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV + +j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D + vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq + yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV + z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx + x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U + 0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX + GA91fn0891b5eEW8BJHXX0jri0aN8g=="], + "n": "", + "e": "AQAB" + }, + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIELTCCAxWgAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShEwDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 + MDkxNzE2MTk0NFoXDTM0MDkxNTE2MTk0NFowTjELMAkGA1UEBhMCVVMxCzAJBgNV + BAgMAkNBMQwwCgYDVQQHDANTVkwxDTALBgNVBAoMBGdSUEMxFTATBgNVBAMMDHRl + c3QtY2xpZW50MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOcTjjcS + SfG/EGrr6G+f+3T2GXyHHfroQFi9mZUz80L7uKBdECOImID+YhoK8vcxLQjPmEEv + FIYgJT5amugDcYIgUhMjBx/8RPJaP/nGmBngAqsuuNCaZfyaHBRqN8XdS/AwmsI5 + Wo+nru0+0/7aQFdqqtd2+e9dHjUWwgHxXvMgC4hkHpsdCGIZWVzWyBliwTYQYb1Y + yYe1LzqqQA5OMbZfKOY9MYDCEYOliRiunOn30iIOHj9V5qLzWGfSyxCRuvLRdEP8 + iDeNweHbdaKuI80nQmxuBdRIspE9k5sD1WA4vLZpeg3zggxp4rfLL5zBJgb/33D3 + d9Rkm14xfDPihhkCAwEAAaOB+jCB9zBZBgNVHREEUjBQhiZzcGlmZmU6Ly9mb28u + YmFyLmNvbS9jbGllbnQvd29ya2xvYWQvMYYmc3BpZmZlOi8vZm9vLmJhci5jb20v + Y2xpZW50L3dvcmtsb2FkLzIwHQYDVR0OBBYEFG9GkBgdBg/p0U9/lXv8zIJ+2c2N + MHsGA1UdIwR0MHKhWqRYMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 + 
YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMM + BnRlc3RjYYIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQELBQADggEB + AJ4Cbxv+02SpUgkEu4hP/1+8DtSBXUxNxI0VG4e3Ap2+Rhjm3YiFeS/UeaZhNrrw + UEjkSTPFODyXR7wI7UO9OO1StyD6CMkp3SEvevU5JsZtGL6mTiTLTi3Qkywa91Bt + GlyZdVMghA1bBJLBMwiD5VT5noqoJBD7hDy6v9yNmt1Sw2iYBJPqI3Gnf5bMjR3s + UICaxmFyqaMCZsPkfJh0DmZpInGJys3m4QqGz6ZE2DWgcSr1r/ML7/5bSPjjr8j4 + WFFSqFR3dMu8CbGnfZTCTXa4GTX/rARXbAO67Z/oJbJBK7VKayskL+PzKuohb9ox + jGL772hQMbwtFCOFXu5VP0s="] + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_malformed.json b/core/src/test/resources/io/grpc/internal/spiffebundle_malformed.json new file mode 100644 index 00000000000..a2488eeb3cd --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_malformed.json @@ -0,0 +1,4 @@ +[ + "test.google.com", + "test.google.com.au" +] \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_kid.json b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_kid.json new file mode 100644 index 00000000000..f93af634a54 --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_kid.json @@ -0,0 +1,15 @@ +{ + "trust_domains": { + "google.com": { + "spiffe_sequence": 123, + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "kid": "some_value", + "x5c": "VALUE_DOESN'T_MATTER" + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_kty.json b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_kty.json new file mode 100644 index 00000000000..384da03fd6f --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_kty.json @@ -0,0 +1,12 @@ +{ + "trust_domains": { + "google.com": { + "spiffe_sequence": 123, + "keys": [ + { + "x5c": "VALUE_DOESN'T_MATTER" + } + ] + } + } +} \ No newline at end of file diff --git 
a/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_multi_certs.json b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_multi_certs.json new file mode 100644 index 00000000000..5e85635bb02 --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_multi_certs.json @@ -0,0 +1,67 @@ +{ + "trust_domains": { + "google.com": { + "spiffe_sequence": 123, + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL + BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL + BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy + NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM + MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu + dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ + LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z + G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO + a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z + JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV + m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 + 7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc + msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc + DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN + zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs + vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI + sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH + HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP + BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t + L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ + aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 + 4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 + IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ + 
PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV + +j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D + vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq + yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV + z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx + x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U + 0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX + GA91fn0891b5eEW8BJHXX0jri0aN8g==", + "MIIELTCCAxWgAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShEwDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 + MDkxNzE2MTk0NFoXDTM0MDkxNTE2MTk0NFowTjELMAkGA1UEBhMCVVMxCzAJBgNV + BAgMAkNBMQwwCgYDVQQHDANTVkwxDTALBgNVBAoMBGdSUEMxFTATBgNVBAMMDHRl + c3QtY2xpZW50MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOcTjjcS + SfG/EGrr6G+f+3T2GXyHHfroQFi9mZUz80L7uKBdECOImID+YhoK8vcxLQjPmEEv + FIYgJT5amugDcYIgUhMjBx/8RPJaP/nGmBngAqsuuNCaZfyaHBRqN8XdS/AwmsI5 + Wo+nru0+0/7aQFdqqtd2+e9dHjUWwgHxXvMgC4hkHpsdCGIZWVzWyBliwTYQYb1Y + yYe1LzqqQA5OMbZfKOY9MYDCEYOliRiunOn30iIOHj9V5qLzWGfSyxCRuvLRdEP8 + iDeNweHbdaKuI80nQmxuBdRIspE9k5sD1WA4vLZpeg3zggxp4rfLL5zBJgb/33D3 + d9Rkm14xfDPihhkCAwEAAaOB+jCB9zBZBgNVHREEUjBQhiZzcGlmZmU6Ly9mb28u + YmFyLmNvbS9jbGllbnQvd29ya2xvYWQvMYYmc3BpZmZlOi8vZm9vLmJhci5jb20v + Y2xpZW50L3dvcmtsb2FkLzIwHQYDVR0OBBYEFG9GkBgdBg/p0U9/lXv8zIJ+2c2N + MHsGA1UdIwR0MHKhWqRYMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 + YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMM + BnRlc3RjYYIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQELBQADggEB + AJ4Cbxv+02SpUgkEu4hP/1+8DtSBXUxNxI0VG4e3Ap2+Rhjm3YiFeS/UeaZhNrrw + UEjkSTPFODyXR7wI7UO9OO1StyD6CMkp3SEvevU5JsZtGL6mTiTLTi3Qkywa91Bt + GlyZdVMghA1bBJLBMwiD5VT5noqoJBD7hDy6v9yNmt1Sw2iYBJPqI3Gnf5bMjR3s + UICaxmFyqaMCZsPkfJh0DmZpInGJys3m4QqGz6ZE2DWgcSr1r/ML7/5bSPjjr8j4 + WFFSqFR3dMu8CbGnfZTCTXa4GTX/rARXbAO67Z/oJbJBK7VKayskL+PzKuohb9ox + 
jGL772hQMbwtFCOFXu5VP0s="] + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_root.json b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_root.json new file mode 100644 index 00000000000..90d2847dc05 --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_root.json @@ -0,0 +1,6 @@ +{ + "trustDomains": { + "test.google.com": {}, + "test.google.com.au": {} + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_seq_type.json b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_seq_type.json new file mode 100644 index 00000000000..4e0aeacc89f --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_seq_type.json @@ -0,0 +1,12 @@ +{ + "trust_domains": { + "google.com": { + "spiffe_sequence": 123.5, + "keys": [ + { + "x5c": "VALUE_DOESN'T_MATTER" + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_use.json b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_use.json new file mode 100644 index 00000000000..166be04846c --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_use.json @@ -0,0 +1,13 @@ +{ + "trust_domains": { + "google.com": { + "keys": [ + { + "kty": "RSA", + "use": "i_am_not_x509-svid", + "x5c": "VALUE_DOESN'T_MATTER" + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java b/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java index 62cbdc4f67b..5d6b88a1392 100644 --- a/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java @@ -24,6 +24,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static 
org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.junit.Assume.assumeTrue; @@ -57,6 +58,7 @@ import io.grpc.MethodDescriptor; import io.grpc.ServerStreamTracer; import io.grpc.Status; +import io.grpc.internal.MockServerTransportListener.StreamCreation; import io.grpc.internal.testing.TestClientStreamTracer; import io.grpc.internal.testing.TestServerStreamTracer; import java.io.ByteArrayInputStream; @@ -68,17 +70,13 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import org.junit.After; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -94,7 +92,10 @@ public abstract class AbstractTransportTest { */ public static final int TEST_FLOW_CONTROL_WINDOW = 65 * 1024; - private static final int TIMEOUT_MS = 5000; + protected static final int TIMEOUT_MS = 5000; + + protected static final String GRPC_EXPERIMENTAL_SUPPORT_TRACING_MESSAGE_SIZES = + "GRPC_EXPERIMENTAL_SUPPORT_TRACING_MESSAGE_SIZES"; private static final Attributes.Key ADDITIONAL_TRANSPORT_ATTR_KEY = Attributes.Key.create("additional-attr"); @@ -136,13 +137,6 @@ protected abstract InternalServer newServer( */ protected abstract String testAuthority(InternalServer server); - /** - * Returns true (which is default) if the transport reports message sizes to StreamTracers. 
- */ - protected boolean sizesReported() { - return true; - } - protected final Attributes eagAttrs() { return EAG_ATTRS; } @@ -163,9 +157,9 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {} * tests in an indeterminate state. */ protected InternalServer server; - private ServerTransport serverTransport; - private ManagedClientTransport client; - private MethodDescriptor methodDescriptor = + protected ServerTransport serverTransport; + protected ManagedClientTransport client; + protected MethodDescriptor methodDescriptor = MethodDescriptor.newBuilder() .setType(MethodDescriptor.MethodType.UNKNOWN) .setFullMethodName("service/method") @@ -182,22 +176,22 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {} "tracer-key", Metadata.ASCII_STRING_MARSHALLER); private final String tracerKeyValue = "tracer-key-value"; - private ManagedClientTransport.Listener mockClientTransportListener + protected ManagedClientTransport.Listener mockClientTransportListener = mock(ManagedClientTransport.Listener.class); - private MockServerListener serverListener = new MockServerListener(); - private ArgumentCaptor throwableCaptor = ArgumentCaptor.forClass(Throwable.class); - private final TestClientStreamTracer clientStreamTracer1 = new TestHeaderClientStreamTracer(); + protected MockServerListener serverListener = new MockServerListener(); + private ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + protected final TestClientStreamTracer clientStreamTracer1 = new TestHeaderClientStreamTracer(); private final TestClientStreamTracer clientStreamTracer2 = new TestHeaderClientStreamTracer(); - private final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + protected final ClientStreamTracer[] tracers = new ClientStreamTracer[] { clientStreamTracer1, clientStreamTracer2 }; private final ClientStreamTracer[] noopTracers = new ClientStreamTracer[] { new ClientStreamTracer() {} }; - private final 
TestServerStreamTracer serverStreamTracer1 = new TestServerStreamTracer(); + protected final TestServerStreamTracer serverStreamTracer1 = new TestServerStreamTracer(); private final TestServerStreamTracer serverStreamTracer2 = new TestServerStreamTracer(); - private final ServerStreamTracer.Factory serverStreamTracerFactory = mock( + protected final ServerStreamTracer.Factory serverStreamTracerFactory = mock( ServerStreamTracer.Factory.class, delegatesTo(new ServerStreamTracer.Factory() { final ArrayDeque tracers = @@ -213,10 +207,6 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata } })); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Before public void setUp() { server = newServer(Arrays.asList(serverStreamTracerFactory)); @@ -245,6 +235,13 @@ protected void advanceClock(long offset, TimeUnit unit) { throw new UnsupportedOperationException(); } + /** + * Returns true if env var is set. + */ + protected static boolean isEnabledSupportTracingMessageSizes() { + return GrpcUtil.getFlag(GRPC_EXPERIMENTAL_SUPPORT_TRACING_MESSAGE_SIZES, false); + } + /** * Returns the current time, for tests that rely on the clock. */ @@ -266,7 +263,7 @@ protected long fakeCurrentTimeNanos() { // (and maybe exceptions handled) /** - * Test for issue https://github.com/grpc/grpc-java/issues/1682 + * Test for issue https://github.com/grpc/grpc-java/issues/1682 . 
*/ @Test public void frameAfterRstStreamShouldNotBreakClientChannel() throws Exception { @@ -298,8 +295,8 @@ public void frameAfterRstStreamShouldNotBreakClientChannel() throws Exception { serverStreamCreation.stream.flush(); assertEquals( - Status.CANCELLED, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status.CANCELLED, clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); ClientStreamListener mockClientStreamListener2 = mock(ClientStreamListener.class); @@ -329,7 +326,8 @@ public void serverNotListening() throws Exception { runIfNotNull(client.start(mockClientTransportListener)); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - inOrder.verify(mockClientTransportListener).transportShutdown(statusCaptor.capture()); + inOrder.verify(mockClientTransportListener).transportShutdown(statusCaptor.capture(), + any(DisconnectError.class)); assertCodeEquals(Status.UNAVAILABLE, statusCaptor.getValue()); inOrder.verify(mockClientTransportListener).transportTerminated(); verify(mockClientTransportListener, never()).transportReady(); @@ -345,7 +343,8 @@ public void clientStartStop() throws Exception { Status shutdownReason = Status.UNAVAILABLE.withDescription("shutdown called"); client.shutdown(shutdownReason); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); - inOrder.verify(mockClientTransportListener).transportShutdown(same(shutdownReason)); + inOrder.verify(mockClientTransportListener).transportShutdown(same(shutdownReason), + any(DisconnectError.class)); inOrder.verify(mockClientTransportListener).transportTerminated(); verify(mockClientTransportListener, never()).transportInUse(anyBoolean()); } @@ -361,7 +360,8 @@ public void 
clientStartAndStopOnceConnected() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); client.shutdown(Status.UNAVAILABLE); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); - inOrder.verify(mockClientTransportListener).transportShutdown(any(Status.class)); + inOrder.verify(mockClientTransportListener).transportShutdown(any(Status.class), + any(DisconnectError.class)); inOrder.verify(mockClientTransportListener).transportTerminated(); assertTrue(serverTransportListener.waitForTermination(TIMEOUT_MS, TimeUnit.MILLISECONDS)); server.shutdown(); @@ -393,8 +393,7 @@ public void serverAlreadyListening() throws Exception { port = ((InetSocketAddress) addr).getPort(); } InternalServer server2 = newServer(port, Arrays.asList(serverStreamTracerFactory)); - thrown.expect(IOException.class); - server2.start(new MockServerListener()); + assertThrows(IOException.class, () -> server2.start(new MockServerListener())); } @Test @@ -458,7 +457,8 @@ public void openStreamPreventsTermination() throws Exception { serverTransport.shutdown(); serverTransport = null; - verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class), + any(DisconnectError.class)); assertTrue(serverListener.waitForShutdown(TIMEOUT_MS, TimeUnit.MILLISECONDS)); // A new server should be able to start listening, since the current server has given up @@ -472,7 +472,7 @@ public void openStreamPreventsTermination() throws Exception { // the stream still functions. 
serverStream.writeHeaders(new Metadata(), true); clientStream.halfClose(); - assertNotNull(clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitHeaders(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); verify(mockClientTransportListener, never()).transportTerminated(); @@ -508,15 +508,16 @@ public void shutdownNowKillsClientStream() throws Exception { client.shutdownNow(status); client = null; - verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class), + any(DisconnectError.class)); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(false); assertTrue(serverTransportListener.waitForTermination(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(serverTransportListener.isTerminated()); - assertEquals(status, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status serverStatus = serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertEquals(status, clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status serverStatus = serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertFalse(serverStatus.isOk()); assertTrue(clientStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertNull(clientStreamTracer1.getInboundTrailers()); @@ -547,15 +548,16 @@ public void shutdownNowKillsServerStream() throws Exception { serverTransport.shutdownNow(shutdownStatus); serverTransport = null; - verify(mockClientTransportListener, 
timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class), + any(DisconnectError.class)); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(false); assertTrue(serverTransportListener.waitForTermination(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(serverTransportListener.isTerminated()); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertFalse(clientStreamStatus.isOk()); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(clientStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertNull(clientStreamTracer1.getInboundTrailers()); assertStatusEquals(clientStreamStatus, clientStreamTracer1.getStatus()); @@ -565,7 +567,7 @@ public void shutdownNowKillsServerStream() throws Exception { // Generally will be same status provided to shutdownNow, but InProcessTransport can't // differentiate between client and server shutdownNow. The status is not really used on // server-side, so we don't care much. 
- assertNotNull(serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); } @Test @@ -595,7 +597,8 @@ public void ping_duringShutdown() throws Exception { ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); client.shutdown(Status.UNAVAILABLE); - verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class), + any(DisconnectError.class)); ClientTransport.PingCallback mockPingCallback = mock(ClientTransport.PingCallback.class); try { client.ping(mockPingCallback, MoreExecutors.directExecutor()); @@ -623,8 +626,8 @@ public void ping_afterTermination() throws Exception { // Transport doesn't support ping, so this neither passes nor fails. assumeTrue(false); } - verify(mockPingCallback, timeout(TIMEOUT_MS)).onFailure(throwableCaptor.capture()); - Status status = Status.fromThrowable(throwableCaptor.getValue()); + verify(mockPingCallback, timeout(TIMEOUT_MS)).onFailure(statusCaptor.capture()); + Status status = statusCaptor.getValue(); assertSame(shutdownReason, status); } @@ -639,15 +642,16 @@ public void newStream_duringShutdown() throws Exception { ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); client.shutdown(Status.UNAVAILABLE); - verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class), + any(DisconnectError.class)); ClientStream stream2 = client.newStream( methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); Status clientStreamStatus2 = - 
clientStreamListener2.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); - assertNotNull(clientStreamListener2.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + clientStreamListener2.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener2.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertCodeEquals(Status.UNAVAILABLE, clientStreamStatus2); assertNull(clientStreamTracer2.getInboundTrailers()); assertSame(clientStreamStatus2, clientStreamTracer2.getStatus()); @@ -661,8 +665,8 @@ public void newStream_duringShutdown() throws Exception { StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(20 * TIMEOUT_MS, TimeUnit.MILLISECONDS); serverStreamCreation.stream.close(Status.OK, new Metadata()); - assertCodeEquals(Status.OK, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertCodeEquals(Status.OK, clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); } @Test @@ -682,8 +686,8 @@ public void newStream_afterTermination() throws Exception { ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); assertEquals( - shutdownReason, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + shutdownReason, clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); verify(mockClientTransportListener, never()).transportInUse(anyBoolean()); assertNull(clientStreamTracer1.getInboundTrailers()); assertSame(shutdownReason, clientStreamTracer1.getStatus()); @@ -791,6 +795,17 @@ public void transportInUse_clientCancel() throws Exception { @Test public void basicStream() 
throws Exception { + serverListener = + new MockServerListener( + transport -> + new MockServerTransportListener(transport) { + @Override + public Attributes transportReady(Attributes attributes) { + return super.transportReady(attributes).toBuilder() + .set(ADDITIONAL_TRANSPORT_ATTR_KEY, "additional attribute value") + .build(); + } + }); InOrder serverInOrder = inOrder(serverStreamTracerFactory); server.start(serverListener); client = newClientTransport(server); @@ -857,25 +872,20 @@ public void basicStream() throws Exception { message.close(); assertThat(clientStreamTracer1.nextOutboundEvent()) .matches("outboundMessageSent\\(0, -?[0-9]+, -?[0-9]+\\)"); - if (sizesReported()) { + if (isEnabledSupportTracingMessageSizes()) { assertThat(clientStreamTracer1.getOutboundWireSize()).isGreaterThan(0L); assertThat(clientStreamTracer1.getOutboundUncompressedSize()).isGreaterThan(0L); - } else { - assertThat(clientStreamTracer1.getOutboundWireSize()).isEqualTo(0L); - assertThat(clientStreamTracer1.getOutboundUncompressedSize()).isEqualTo(0L); } + assertThat(serverStreamTracer1.nextInboundEvent()).isEqualTo("inboundMessage(0)"); assertNull("no additional message expected", serverStreamListener.messageQueue.poll()); clientStream.halfClose(); assertTrue(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - if (sizesReported()) { + if (isEnabledSupportTracingMessageSizes()) { assertThat(serverStreamTracer1.getInboundWireSize()).isGreaterThan(0L); assertThat(serverStreamTracer1.getInboundUncompressedSize()).isGreaterThan(0L); - } else { - assertThat(serverStreamTracer1.getInboundWireSize()).isEqualTo(0L); - assertThat(serverStreamTracer1.getInboundUncompressedSize()).isEqualTo(0L); } assertThat(serverStreamTracer1.nextInboundEvent()) .matches("inboundMessageRead\\(0, -?[0-9]+, -?[0-9]+\\)"); @@ -889,7 +899,7 @@ public void basicStream() throws Exception { Metadata serverHeadersCopy = new Metadata(); serverHeadersCopy.merge(serverHeaders); 
serverStream.writeHeaders(serverHeaders, true); - Metadata headers = clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Metadata headers = clientStreamListener.awaitHeaders(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotNull(headers); assertAsciiMetadataValuesEqual(serverHeadersCopy.getAll(asciiKey), headers.getAll(asciiKey)); assertEquals( @@ -907,24 +917,18 @@ public void basicStream() throws Exception { assertNotNull("message expected", message); assertThat(serverStreamTracer1.nextOutboundEvent()) .matches("outboundMessageSent\\(0, -?[0-9]+, -?[0-9]+\\)"); - if (sizesReported()) { + if (isEnabledSupportTracingMessageSizes()) { assertThat(serverStreamTracer1.getOutboundWireSize()).isGreaterThan(0L); assertThat(serverStreamTracer1.getOutboundUncompressedSize()).isGreaterThan(0L); - } else { - assertThat(serverStreamTracer1.getOutboundWireSize()).isEqualTo(0L); - assertThat(serverStreamTracer1.getOutboundUncompressedSize()).isEqualTo(0L); } assertTrue(clientStreamTracer1.getInboundHeaders()); assertThat(clientStreamTracer1.nextInboundEvent()).isEqualTo("inboundMessage(0)"); assertEquals("Hi. 
Who are you?", methodDescriptor.parseResponse(message)); assertThat(clientStreamTracer1.nextInboundEvent()) .matches("inboundMessageRead\\(0, -?[0-9]+, -?[0-9]+\\)"); - if (sizesReported()) { + if (isEnabledSupportTracingMessageSizes()) { assertThat(clientStreamTracer1.getInboundWireSize()).isGreaterThan(0L); assertThat(clientStreamTracer1.getInboundUncompressedSize()).isGreaterThan(0L); - } else { - assertThat(clientStreamTracer1.getInboundWireSize()).isEqualTo(0L); - assertThat(clientStreamTracer1.getInboundUncompressedSize()).isEqualTo(0L); } message.close(); @@ -940,11 +944,11 @@ public void basicStream() throws Exception { serverStream.close(status, trailers); assertNull(serverStreamTracer1.nextInboundEvent()); assertNull(serverStreamTracer1.nextOutboundEvent()); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertSame(status, serverStreamTracer1.getStatus()); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientStreamTrailers = - clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertSame(clientStreamTrailers, clientStreamTracer1.getInboundTrailers()); assertSame(clientStreamStatus, clientStreamTracer1.getStatus()); assertNull(clientStreamTracer1.nextInboundEvent()); @@ -1013,14 +1017,14 @@ public void zeroMessageStream() throws Exception { assertTrue(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); serverStream.writeHeaders(new Metadata(), true); - assertNotNull(clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitHeaders(TIMEOUT_MS, TimeUnit.MILLISECONDS)); Status status = 
Status.OK.withDescription("Nice talking to you"); serverStream.close(status, new Metadata()); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientStreamTrailers = - clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotNull(clientStreamTrailers); assertEquals(status.getCode(), clientStreamStatus.getCode()); assertEquals(status.getDescription(), clientStreamStatus.getDescription()); @@ -1050,15 +1054,15 @@ public void earlyServerClose_withServerHeaders() throws Exception { ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; serverStream.writeHeaders(new Metadata(), true); - assertNotNull(clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitHeaders(TIMEOUT_MS, TimeUnit.MILLISECONDS)); Status strippedStatus = Status.OK.withDescription("Hello. 
Goodbye."); Status status = strippedStatus.withCause(new Exception()); serverStream.close(status, new Metadata()); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientStreamTrailers = - clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotNull(clientStreamTrailers); checkClientStatus(status, clientStreamStatus); assertTrue(clientStreamTracer1.getOutboundHeaders()); @@ -1094,10 +1098,10 @@ public void earlyServerClose_noServerHeaders() throws Exception { trailers.put(asciiKey, "dupvalue"); trailers.put(binaryKey, "äbinarytrailers"); serverStream.close(status, trailers); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientStreamTrailers = - clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); checkClientStatus(status, clientStreamStatus); assertEquals( Lists.newArrayList(trailers.getAll(asciiKey)), @@ -1132,10 +1136,10 @@ public void earlyServerClose_serverFailure() throws Exception { Status strippedStatus = Status.INTERNAL.withDescription("I'm not listening"); Status status = strippedStatus.withCause(new Exception()); serverStream.close(status, new Metadata()); - 
assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientStreamTrailers = - clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotNull(clientStreamTrailers); checkClientStatus(status, clientStreamStatus); assertTrue(clientStreamTracer1.getOutboundHeaders()); @@ -1175,10 +1179,10 @@ public void closed( Status strippedStatus = Status.INTERNAL.withDescription("I'm not listening"); Status status = strippedStatus.withCause(new Exception()); serverStream.close(status, new Metadata()); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientStreamTrailers = - clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotNull(clientStreamTrailers); checkClientStatus(status, clientStreamStatus); assertTrue(clientStreamTracer1.getOutboundHeaders()); @@ -1206,9 +1210,9 @@ public void clientCancel() throws Exception { Status status = Status.CANCELLED.withDescription("Nevermind").withCause(new Exception()); clientStream.cancel(status); - assertEquals(status, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, 
TimeUnit.MILLISECONDS)); - Status serverStatus = serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertEquals(status, clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status serverStatus = serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotEquals(Status.Code.OK, serverStatus.getCode()); // Cause should not be transmitted between client and server by default assertNull(serverStatus.getCause()); @@ -1285,16 +1289,11 @@ public void onReady() { serverStream.close(Status.OK, new Metadata()); assertTrue(clientStreamTracer1.getOutboundHeaders()); assertTrue(clientStreamTracer1.getInboundHeaders()); - if (sizesReported()) { + if (isEnabledSupportTracingMessageSizes()) { assertThat(clientStreamTracer1.getInboundWireSize()).isGreaterThan(0L); assertThat(clientStreamTracer1.getInboundUncompressedSize()).isGreaterThan(0L); assertThat(serverStreamTracer1.getOutboundWireSize()).isGreaterThan(0L); assertThat(serverStreamTracer1.getOutboundUncompressedSize()).isGreaterThan(0L); - } else { - assertThat(clientStreamTracer1.getInboundWireSize()).isEqualTo(0L); - assertThat(clientStreamTracer1.getInboundUncompressedSize()).isEqualTo(0L); - assertThat(serverStreamTracer1.getOutboundWireSize()).isEqualTo(0L); - assertThat(serverStreamTracer1.getOutboundUncompressedSize()).isEqualTo(0L); } assertNull(clientStreamTracer1.getInboundTrailers()); assertSame(status, clientStreamTracer1.getStatus()); @@ -1325,9 +1324,9 @@ public void serverCancel() throws Exception { Status status = Status.DEADLINE_EXCEEDED.withDescription("It was bound to happen") .withCause(new Exception()); serverStream.cancel(status); - assertEquals(status, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); - 
assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertEquals(status, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); // Presently we can't sent much back to the client in this case. Verify that is the current // behavior for consistency between transports. assertCodeEquals(Status.CANCELLED, clientStreamStatus); @@ -1458,7 +1457,7 @@ public void flowControlPushBack() throws Exception { clientStream.flush(); clientStream.halfClose(); doPingPong(serverListener); - assertFalse(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertFalse(serverStreamListener.isHalfClosed()); serverStream.request(1); serverReceived += verifyMessageCountAndClose(serverStreamListener.messageQueue, 1); @@ -1470,18 +1469,14 @@ public void flowControlPushBack() throws Exception { Status status = Status.OK.withDescription("... 
quite a lengthy discussion"); serverStream.close(status, new Metadata()); doPingPong(serverListener); - try { - clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); - fail("Expected TimeoutException"); - } catch (TimeoutException expectedException) { - } + assertFalse(clientStreamListener.isClosed()); clientStream.request(1); clientReceived += verifyMessageCountAndClose(clientStreamListener.messageQueue, 1); assertEquals(serverSent + 6, clientReceived); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertEquals(status.getCode(), clientStreamStatus.getCode()); assertEquals(status.getDescription(), clientStreamStatus.getDescription()); } @@ -1537,9 +1532,9 @@ public void flowControlDoesNotDeadlockLargeMessage() throws Exception { serverStream.close(status, new Metadata()); doPingPong(serverListener); clientStream.request(1); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertEquals(status.getCode(), 
clientStreamStatus.getCode()); assertEquals(status.getDescription(), clientStreamStatus.getDescription()); } @@ -1607,8 +1602,8 @@ public void interactionsAfterServerStreamCloseAreNoops() throws Exception { // setup clientStream.request(1); server.stream.close(Status.INTERNAL, new Metadata()); - assertNotNull(clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); // Ensure that for a closed ServerStream, interactions are noops server.stream.writeHeaders(new Metadata(), true); @@ -1640,7 +1635,7 @@ public void interactionsAfterClientStreamCancelAreNoops() throws Exception { // setup server.stream.request(1); clientStream.cancel(Status.UNKNOWN); - assertNotNull(server.listener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(server.listener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); // Ensure that for a cancelled ClientStream, interactions are noops clientStream.writeMessage(methodDescriptor.streamRequest("request")); @@ -1763,9 +1758,8 @@ public void transportTracer_server_streamEnded_ok() throws Exception { clientStream.halfClose(); serverStream.close(Status.OK, new Metadata()); // do not validate stats until close() has been called on client - assertNotNull(clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - + assertNotNull(clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); TransportStats serverAfter = getTransportStats(serverTransportListener.transport); assertEquals(1, serverAfter.streamsSucceeded); @@ -1802,9 +1796,8 @@ public void 
transportTracer_server_streamEnded_nonOk() throws Exception { serverStream.close(Status.UNKNOWN, new Metadata()); // do not validate stats until close() has been called on client - assertNotNull(clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - + assertNotNull(clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); TransportStats serverAfter = getTransportStats(serverTransportListener.transport); assertEquals(1, serverAfter.streamsFailed); @@ -1842,7 +1835,7 @@ public void transportTracer_client_streamEnded_nonOk() throws Exception { clientStream.cancel(Status.UNKNOWN); // do not validate stats until close() has been called on server - assertNotNull(serverStreamCreation.listener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(serverStreamCreation.listener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); TransportStats serverAfter = getTransportStats(serverTransportListener.transport); assertEquals(1, serverAfter.streamsFailed); @@ -1999,7 +1992,7 @@ public void serverChecksInboundMetadataSize() throws Exception { // Server shouldn't have created a stream, so nothing to clean up on server-side // If this times out, the server probably isn't noticing the metadata size - Status status = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Status status = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); List codeOptions = Arrays.asList( Status.Code.UNKNOWN, Status.Code.RESOURCE_EXHAUSTED, Status.Code.INTERNAL); if (!codeOptions.contains(status.getCode())) { @@ -2040,13 +2033,13 @@ public void clientChecksInboundMetadataSize_header() throws Exception { serverStreamCreation.stream.writeMessage(methodDescriptor.streamResponse("response")); serverStreamCreation.stream.close(Status.OK, new Metadata()); - Status status 
= clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Status status = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); List codeOptions = Arrays.asList( Status.Code.UNKNOWN, Status.Code.RESOURCE_EXHAUSTED, Status.Code.INTERNAL); if (!codeOptions.contains(status.getCode())) { fail("Status code was not expected: " + status); } - assertFalse(clientStreamListener.headers.isDone()); + assertFalse(clientStreamListener.hasHeaders()); } /** This assumes the client limits metadata size to GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE. */ @@ -2085,13 +2078,13 @@ public void clientChecksInboundMetadataSize_trailer() throws Exception { serverStreamCreation.stream.writeMessage(methodDescriptor.streamResponse("response")); serverStreamCreation.stream.close(Status.OK, tooLargeMetadata); - Status status = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Status status = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); List codeOptions = Arrays.asList( Status.Code.UNKNOWN, Status.Code.RESOURCE_EXHAUSTED, Status.Code.INTERNAL); if (!codeOptions.contains(status.getCode())) { fail("Status code was not expected: " + status); } - Metadata metadata = clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Metadata metadata = clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNull(metadata.get(tellTaleKey)); } @@ -2119,9 +2112,9 @@ methodDescriptor, new Metadata(), callOptions, ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; serverStream.close(Status.OK, new Metadata()); - assertNotNull(clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + 
assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); client.shutdown(Status.UNAVAILABLE); } @@ -2166,7 +2159,7 @@ private static void checkClientStatus(Status expectedStatus, Status clientStream assertNull(clientStreamStatus.getCause()); } - private static boolean waitForFuture(Future future, long timeout, TimeUnit unit) + static boolean waitForFuture(Future future, long timeout, TimeUnit unit) throws InterruptedException { try { future.get(timeout, unit); @@ -2178,13 +2171,13 @@ private static boolean waitForFuture(Future future, long timeout, TimeUnit un return true; } - private static void runIfNotNull(Runnable runnable) { + protected static void runIfNotNull(Runnable runnable) { if (runnable != null) { runnable.run(); } } - private static void startTransport( + protected static void startTransport( ManagedClientTransport clientTransport, ManagedClientTransport.Listener listener) { runIfNotNull(clientTransport.start(listener)); @@ -2202,218 +2195,6 @@ public void streamCreated(Attributes transportAttrs, Metadata metadata) { } } - private static class MockServerListener implements ServerListener { - public final BlockingQueue listeners - = new LinkedBlockingQueue<>(); - private final SettableFuture shutdown = SettableFuture.create(); - - @Override - public ServerTransportListener transportCreated(ServerTransport transport) { - MockServerTransportListener listener = new MockServerTransportListener(transport); - listeners.add(listener); - return listener; - } - - @Override - public void serverShutdown() { - assertTrue(shutdown.set(null)); - } - - public boolean waitForShutdown(long timeout, TimeUnit unit) throws InterruptedException { - return waitForFuture(shutdown, timeout, unit); - } - - public MockServerTransportListener takeListenerOrFail(long timeout, TimeUnit unit) - throws InterruptedException { - MockServerTransportListener listener = 
listeners.poll(timeout, unit); - if (listener == null) { - fail("Timed out waiting for server transport"); - } - return listener; - } - } - - private static class MockServerTransportListener implements ServerTransportListener { - public final ServerTransport transport; - public final BlockingQueue streams = new LinkedBlockingQueue<>(); - private final SettableFuture terminated = SettableFuture.create(); - - public MockServerTransportListener(ServerTransport transport) { - this.transport = transport; - } - - @Override - public void streamCreated(ServerStream stream, String method, Metadata headers) { - ServerStreamListenerBase listener = new ServerStreamListenerBase(); - streams.add(new StreamCreation(stream, method, headers, listener)); - stream.setListener(listener); - } - - @Override - public Attributes transportReady(Attributes attributes) { - assertFalse(terminated.isDone()); - return Attributes.newBuilder() - .setAll(attributes) - .set(ADDITIONAL_TRANSPORT_ATTR_KEY, "additional attribute value") - .build(); - } - - @Override - public void transportTerminated() { - assertTrue(terminated.set(null)); - } - - public boolean waitForTermination(long timeout, TimeUnit unit) throws InterruptedException { - return waitForFuture(terminated, timeout, unit); - } - - public boolean isTerminated() { - return terminated.isDone(); - } - - public StreamCreation takeStreamOrFail(long timeout, TimeUnit unit) - throws InterruptedException { - StreamCreation stream = streams.poll(timeout, unit); - if (stream == null) { - fail("Timed out waiting for server stream"); - } - return stream; - } - } - - private static class ServerStreamListenerBase implements ServerStreamListener { - private final BlockingQueue messageQueue = new LinkedBlockingQueue<>(); - // Would have used Void instead of Object, but null elements are not allowed - private final BlockingQueue readyQueue = new LinkedBlockingQueue<>(); - private final CountDownLatch halfClosedLatch = new CountDownLatch(1); - private 
final SettableFuture status = SettableFuture.create(); - - private boolean awaitOnReady(int timeout, TimeUnit unit) throws Exception { - return readyQueue.poll(timeout, unit) != null; - } - - private boolean awaitOnReadyAndDrain(int timeout, TimeUnit unit) throws Exception { - if (!awaitOnReady(timeout, unit)) { - return false; - } - // Throw the rest away - readyQueue.drainTo(Lists.newArrayList()); - return true; - } - - private boolean awaitHalfClosed(int timeout, TimeUnit unit) throws Exception { - return halfClosedLatch.await(timeout, unit); - } - - @Override - public void messagesAvailable(MessageProducer producer) { - if (status.isDone()) { - fail("messagesAvailable invoked after closed"); - } - InputStream message; - while ((message = producer.next()) != null) { - messageQueue.add(message); - } - } - - @Override - public void onReady() { - if (status.isDone()) { - fail("onReady invoked after closed"); - } - readyQueue.add(new Object()); - } - - @Override - public void halfClosed() { - if (status.isDone()) { - fail("halfClosed invoked after closed"); - } - halfClosedLatch.countDown(); - } - - @Override - public void closed(Status status) { - if (this.status.isDone()) { - fail("closed invoked more than once"); - } - this.status.set(status); - } - } - - private static class ClientStreamListenerBase implements ClientStreamListener { - private final BlockingQueue messageQueue = new LinkedBlockingQueue<>(); - // Would have used Void instead of Object, but null elements are not allowed - private final BlockingQueue readyQueue = new LinkedBlockingQueue<>(); - private final SettableFuture headers = SettableFuture.create(); - private final SettableFuture trailers = SettableFuture.create(); - private final SettableFuture status = SettableFuture.create(); - - private boolean awaitOnReady(int timeout, TimeUnit unit) throws Exception { - return readyQueue.poll(timeout, unit) != null; - } - - private boolean awaitOnReadyAndDrain(int timeout, TimeUnit unit) throws Exception 
{ - if (!awaitOnReady(timeout, unit)) { - return false; - } - // Throw the rest away - readyQueue.drainTo(Lists.newArrayList()); - return true; - } - - @Override - public void messagesAvailable(MessageProducer producer) { - if (status.isDone()) { - fail("messagesAvailable invoked after closed"); - } - InputStream message; - while ((message = producer.next()) != null) { - messageQueue.add(message); - } - } - - @Override - public void onReady() { - if (status.isDone()) { - fail("onReady invoked after closed"); - } - readyQueue.add(new Object()); - } - - @Override - public void headersRead(Metadata headers) { - if (status.isDone()) { - fail("headersRead invoked after closed"); - } - this.headers.set(headers); - } - - @Override - public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { - if (this.status.isDone()) { - fail("headersRead invoked after closed"); - } - this.status.set(status); - this.trailers.set(trailers); - } - } - - private static class StreamCreation { - public final ServerStream stream; - public final String method; - public final Metadata headers; - public final ServerStreamListenerBase listener; - - public StreamCreation( - ServerStream stream, String method, Metadata headers, ServerStreamListenerBase listener) { - this.stream = stream; - this.method = method; - this.headers = headers; - this.listener = listener; - } - } - private static class StringMarshaller implements MethodDescriptor.Marshaller { public static final StringMarshaller INSTANCE = new StringMarshaller(); diff --git a/core/src/testFixtures/java/io/grpc/internal/ClientStreamListenerBase.java b/core/src/testFixtures/java/io/grpc/internal/ClientStreamListenerBase.java new file mode 100644 index 00000000000..3c35cf59225 --- /dev/null +++ b/core/src/testFixtures/java/io/grpc/internal/ClientStreamListenerBase.java @@ -0,0 +1,126 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this 
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.fail; + +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Metadata; +import io.grpc.Status; +import java.io.InputStream; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +public class ClientStreamListenerBase implements ClientStreamListener { + public final BlockingQueue messageQueue = new LinkedBlockingQueue<>(); + // Would have used Void instead of Object, but null elements are not allowed + private final BlockingQueue readyQueue = new LinkedBlockingQueue<>(); + private final SettableFuture headers = SettableFuture.create(); + private final SettableFuture trailers = SettableFuture.create(); + private final SettableFuture status = SettableFuture.create(); + + /** + * Returns the stream's status or throws {@link java.util.concurrent.TimeoutException} if it isn't + * closed before the timeout. + */ + public Status awaitClose(int timeout, TimeUnit unit) throws Exception { + return status.get(timeout, unit); + } + + /** + * Return {@code true} if {@code #awaitClose} would return immediately with a status. + */ + public boolean isClosed() { + return status.isDone(); + } + + /** + * Returns response headers from the server or throws {@link + * java.util.concurrent.TimeoutException} if they aren't delivered before the timeout. + * + *

Callers must not modify the returned object. + */ + public Metadata awaitHeaders(int timeout, TimeUnit unit) throws Exception { + return headers.get(timeout, unit); + } + + /** + * Returns response trailers from the server or throws {@link + * java.util.concurrent.TimeoutException} if they aren't delivered before the timeout. + * + *

Callers must not modify the returned object. + */ + public Metadata awaitTrailers(int timeout, TimeUnit unit) throws Exception { + return trailers.get(timeout, unit); + } + + public boolean awaitOnReady(int timeout, TimeUnit unit) throws Exception { + return readyQueue.poll(timeout, unit) != null; + } + + public boolean awaitOnReadyAndDrain(int timeout, TimeUnit unit) throws Exception { + if (!awaitOnReady(timeout, unit)) { + return false; + } + // Throw the rest away + readyQueue.drainTo(Lists.newArrayList()); + return true; + } + + @Override + public void messagesAvailable(MessageProducer producer) { + if (status.isDone()) { + fail("messagesAvailable invoked after closed"); + } + InputStream message; + while ((message = producer.next()) != null) { + messageQueue.add(message); + } + } + + @Override + public void onReady() { + if (status.isDone()) { + fail("onReady invoked after closed"); + } + readyQueue.add(new Object()); + } + + @Override + public void headersRead(Metadata headers) { + if (status.isDone()) { + fail("headersRead invoked after closed"); + } + this.headers.set(headers); + } + + @Override + public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { + if (this.status.isDone()) { + fail("headersRead invoked after closed"); + } + this.status.set(status); + this.trailers.set(trailers); + } + + /** Returns true iff response headers have been received from the server. */ + public boolean hasHeaders() { + return headers.isDone(); + } +} diff --git a/core/src/testFixtures/java/io/grpc/internal/MockServerListener.java b/core/src/testFixtures/java/io/grpc/internal/MockServerListener.java new file mode 100644 index 00000000000..0c33b98cf1c --- /dev/null +++ b/core/src/testFixtures/java/io/grpc/internal/MockServerListener.java @@ -0,0 +1,78 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.common.util.concurrent.SettableFuture; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * A {@link ServerListener} that helps you write blocking unit tests. + * + *

TODO: Rename, since this is not actually a mock: + * https://testing.googleblog.com/2013/07/testing-on-toilet-know-your-test-doubles.html + */ +public class MockServerListener implements ServerListener { + private final BlockingQueue listeners = new LinkedBlockingQueue<>(); + private final SettableFuture shutdown = SettableFuture.create(); + private final ServerTransportListenerFactory serverTransportListenerFactory; + + /** + * Lets you customize the {@link MockServerTransportListener} installed on newly created + * {@link ServerTransport}s. + */ + public interface ServerTransportListenerFactory { + MockServerTransportListener create(ServerTransport transport); + } + + public MockServerListener(ServerTransportListenerFactory serverTransportListenerFactory) { + this.serverTransportListenerFactory = serverTransportListenerFactory; + } + + public MockServerListener() { + this(MockServerTransportListener::new); + } + + @Override + public ServerTransportListener transportCreated(ServerTransport transport) { + MockServerTransportListener listener = serverTransportListenerFactory.create(transport); + listeners.add(listener); + return listener; + } + + @Override + public void serverShutdown() { + assertTrue(shutdown.set(null)); + } + + public boolean waitForShutdown(long timeout, TimeUnit unit) throws InterruptedException { + return AbstractTransportTest.waitForFuture(shutdown, timeout, unit); + } + + public MockServerTransportListener takeListenerOrFail(long timeout, TimeUnit unit) + throws InterruptedException { + MockServerTransportListener listener = listeners.poll(timeout, unit); + if (listener == null) { + fail("Timed out waiting for server transport"); + } + return listener; + } +} diff --git a/core/src/testFixtures/java/io/grpc/internal/MockServerTransportListener.java b/core/src/testFixtures/java/io/grpc/internal/MockServerTransportListener.java new file mode 100644 index 00000000000..e6c4e2f578e --- /dev/null +++ 
b/core/src/testFixtures/java/io/grpc/internal/MockServerTransportListener.java @@ -0,0 +1,93 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Attributes; +import io.grpc.Metadata; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * A {@link ServerTransportListener} that helps you write blocking unit tests. + * + *

TODO: Rename, since this is not actually a mock: + * https://testing.googleblog.com/2013/07/testing-on-toilet-know-your-test-doubles.html + */ +public class MockServerTransportListener implements ServerTransportListener { + public final ServerTransport transport; + private final BlockingQueue streams = new LinkedBlockingQueue<>(); + private final SettableFuture terminated = SettableFuture.create(); + + public MockServerTransportListener(ServerTransport transport) { + this.transport = transport; + } + + @Override + public void streamCreated(ServerStream stream, String method, Metadata headers) { + ServerStreamListenerBase listener = new ServerStreamListenerBase(); + streams.add(new StreamCreation(stream, method, headers, listener)); + stream.setListener(listener); + } + + @Override + public Attributes transportReady(Attributes attributes) { + assertFalse(terminated.isDone()); + return attributes; + } + + @Override + public void transportTerminated() { + assertTrue(terminated.set(null)); + } + + public boolean waitForTermination(long timeout, TimeUnit unit) throws InterruptedException { + return AbstractTransportTest.waitForFuture(terminated, timeout, unit); + } + + public boolean isTerminated() { + return terminated.isDone(); + } + + public StreamCreation takeStreamOrFail(long timeout, TimeUnit unit) throws InterruptedException { + StreamCreation stream = streams.poll(timeout, unit); + if (stream == null) { + fail("Timed out waiting for server stream"); + } + return stream; + } + + public static class StreamCreation { + public final ServerStream stream; + public final String method; + public final Metadata headers; + public final ServerStreamListenerBase listener; + + public StreamCreation( + ServerStream stream, String method, Metadata headers, ServerStreamListenerBase listener) { + this.stream = stream; + this.method = method; + this.headers = headers; + this.listener = listener; + } + } +} diff --git 
a/core/src/testFixtures/java/io/grpc/internal/PickFirstLoadBalancerProviderAccessor.java b/core/src/testFixtures/java/io/grpc/internal/PickFirstLoadBalancerProviderAccessor.java new file mode 100644 index 00000000000..a6e94df03c2 --- /dev/null +++ b/core/src/testFixtures/java/io/grpc/internal/PickFirstLoadBalancerProviderAccessor.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +/** + * Accessor for PickFirstLoadBalancerProvider, allowing access only during tests. 
+ */ +public final class PickFirstLoadBalancerProviderAccessor { + private PickFirstLoadBalancerProviderAccessor() {} + + public static void setEnableNewPickFirst(boolean enableNewPickFirst) { + PickFirstLoadBalancerProvider.enableNewPickFirst = enableNewPickFirst; + } +} diff --git a/core/src/testFixtures/java/io/grpc/internal/ReadableBufferTestBase.java b/core/src/testFixtures/java/io/grpc/internal/ReadableBufferTestBase.java index 202fb7ee8a4..2262f0466f7 100644 --- a/core/src/testFixtures/java/io/grpc/internal/ReadableBufferTestBase.java +++ b/core/src/testFixtures/java/io/grpc/internal/ReadableBufferTestBase.java @@ -21,7 +21,6 @@ import static org.junit.Assert.assertEquals; import java.io.ByteArrayOutputStream; -import java.nio.Buffer; import java.nio.ByteBuffer; import java.util.Arrays; import org.junit.Assume; @@ -83,30 +82,6 @@ public void partialReadToStreamShouldSucceed() throws Exception { assertEquals(msg.length() - 2, buffer.readableBytes()); } - @Test - public void readToByteBufferShouldSucceed() { - ReadableBuffer buffer = buffer(); - ByteBuffer byteBuffer = ByteBuffer.allocate(msg.length()); - buffer.readBytes(byteBuffer); - ((Buffer) byteBuffer).flip(); - byte[] array = new byte[msg.length()]; - byteBuffer.get(array); - assertArrayEquals(msg.getBytes(UTF_8), array); - assertEquals(0, buffer.readableBytes()); - } - - @Test - public void partialReadToByteBufferShouldSucceed() { - ReadableBuffer buffer = buffer(); - ByteBuffer byteBuffer = ByteBuffer.allocate(2); - buffer.readBytes(byteBuffer); - ((Buffer) byteBuffer).flip(); - byte[] array = new byte[2]; - byteBuffer.get(array); - assertArrayEquals(new byte[]{'h', 'e'}, array); - assertEquals(msg.length() - 2, buffer.readableBytes()); - } - @Test public void partialReadToReadableBufferShouldSucceed() { ReadableBuffer buffer = buffer(); diff --git a/core/src/testFixtures/java/io/grpc/internal/ServerStreamListenerBase.java b/core/src/testFixtures/java/io/grpc/internal/ServerStreamListenerBase.java new 
file mode 100644 index 00000000000..aaa70600542 --- /dev/null +++ b/core/src/testFixtures/java/io/grpc/internal/ServerStreamListenerBase.java @@ -0,0 +1,99 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.fail; + +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Status; +import java.io.InputStream; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * A {@link ServerStreamListener} that helps you write blocking unit tests. 
+ */ +public class ServerStreamListenerBase implements ServerStreamListener { + public final BlockingQueue messageQueue = new LinkedBlockingQueue<>(); + // Would have used Void instead of Object, but null elements are not allowed + private final BlockingQueue readyQueue = new LinkedBlockingQueue<>(); + private final CountDownLatch halfClosedLatch = new CountDownLatch(1); + private final SettableFuture status = SettableFuture.create(); + + public boolean awaitOnReady(int timeout, TimeUnit unit) throws Exception { + return readyQueue.poll(timeout, unit) != null; + } + + public boolean awaitOnReadyAndDrain(int timeout, TimeUnit unit) throws Exception { + if (!awaitOnReady(timeout, unit)) { + return false; + } + // Throw the rest away + readyQueue.drainTo(Lists.newArrayList()); + return true; + } + + public boolean awaitHalfClosed(int timeout, TimeUnit unit) throws Exception { + return halfClosedLatch.await(timeout, unit); + } + + public boolean isHalfClosed() { + return halfClosedLatch.getCount() == 0; + } + + public Status awaitClose(int timeout, TimeUnit unit) throws Exception { + return status.get(timeout, unit); + } + + @Override + public void messagesAvailable(MessageProducer producer) { + if (status.isDone()) { + fail("messagesAvailable invoked after closed"); + } + InputStream message; + while ((message = producer.next()) != null) { + messageQueue.add(message); + } + } + + @Override + public void onReady() { + if (status.isDone()) { + fail("onReady invoked after closed"); + } + readyQueue.add(new Object()); + } + + @Override + public void halfClosed() { + if (status.isDone()) { + fail("halfClosed invoked after closed"); + } + halfClosedLatch.countDown(); + } + + @Override + public void closed(Status status) { + if (this.status.isDone()) { + fail("closed invoked more than once"); + } + this.status.set(status); + } +} diff --git a/cronet/build.gradle b/cronet/build.gradle index 3252a9d249b..e096761ddd2 100644 --- a/cronet/build.gradle +++ b/cronet/build.gradle @@ 
-8,14 +8,13 @@ description = "gRPC: Cronet Android" repositories { google() - mavenCentral() } android { - namespace 'io.grpc.cronet' + namespace = 'io.grpc.cronet' compileSdkVersion 33 defaultConfig { - minSdkVersion 21 + minSdkVersion 23 targetSdkVersion 33 versionCode 1 versionName "1.0" @@ -47,6 +46,7 @@ dependencies { libraries.cronet.api implementation project(':grpc-core') implementation libraries.guava + implementation 'org.checkerframework:checker-qual:3.49.5' testImplementation project(':grpc-testing') testImplementation libraries.cronet.embedded diff --git a/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java b/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java index f42dabdd55a..7ea1bc891c2 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java @@ -21,7 +21,6 @@ import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import android.net.Network; -import android.os.Build; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.MoreExecutors; @@ -340,9 +339,7 @@ public BidirectionalStream.Builder newBidirectionalStreamBuilder( builder.setTrafficStatsUid(trafficStatsUid); } if (network != null) { - if (Build.VERSION.SDK_INT >= 23) { - builder.bindToNetwork(network.getNetworkHandle()); - } + builder.bindToNetwork(network.getNetworkHandle()); } return builder; } diff --git a/cronet/src/main/java/io/grpc/cronet/CronetClientStream.java b/cronet/src/main/java/io/grpc/cronet/CronetClientStream.java index 9ae97652316..07bbb953489 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetClientStream.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetClientStream.java @@ -25,6 +25,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.io.BaseEncoding; +import 
com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.InternalMetadata; @@ -50,7 +51,6 @@ import java.util.Map; import java.util.concurrent.Executor; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import org.chromium.net.BidirectionalStream; import org.chromium.net.CronetException; import org.chromium.net.UrlResponseInfo; @@ -59,7 +59,6 @@ * Client stream for the cronet transport. */ class CronetClientStream extends AbstractClientStream { - private static final int READ_BUFFER_CAPACITY = 4 * 1024; private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0); private static final String LOG_TAG = "grpc-java-cronet"; @@ -69,6 +68,12 @@ class CronetClientStream extends AbstractClientStream { static final CallOptions.Key> CRONET_ANNOTATIONS_KEY = CallOptions.Key.create("cronet-annotations"); + /** + * Sets the read buffer size which the GRPC layer will use to read data from Cronet. Higher buffer + * size leads to less overhead but more memory consumption. The current default value is 4KB. 
+ */ + static final CallOptions.Key CRONET_READ_BUFFER_SIZE_KEY = + CallOptions.Key.createWithDefault("cronet-read-buffer-size", 4 * 1024); private final String url; private final String userAgent; @@ -85,6 +90,8 @@ class CronetClientStream extends AbstractClientStream { private final Collection annotations; private final TransportState state; private final Sink sink = new Sink(); + @VisibleForTesting + final int readBufferSize; private StreamBuilderFactory streamFactory; CronetClientStream( @@ -120,6 +127,7 @@ class CronetClientStream extends AbstractClientStream { this.annotations = callOptions.getOption(CRONET_ANNOTATIONS_KEY); this.state = new TransportState(maxMessageSize, statsTraceCtx, lock, transportTracer, callOptions); + this.readBufferSize = callOptions.getOption(CRONET_READ_BUFFER_SIZE_KEY); // Tests expect the "plain" deframer behavior, not MigratingDeframer // https://github.com/grpc/grpc-java/issues/7140 @@ -309,7 +317,7 @@ public void bytesRead(int processedBytes) { if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) { Log.v(LOG_TAG, "BidirectionalStream.read"); } - stream.read(ByteBuffer.allocateDirect(READ_BUFFER_CAPACITY)); + stream.read(ByteBuffer.allocateDirect(readBufferSize)); } } @@ -362,7 +370,6 @@ private static boolean isApplicationHeader(String key) { private void setGrpcHeaders(BidirectionalStream.Builder builder) { // Psuedo-headers are set by cronet. // All non-pseudo headers must come after pseudo headers. - // TODO(ericgribkoff): remove this and set it on CronetEngine after crbug.com/588204 gets fixed. 
builder.addHeader(USER_AGENT_KEY.name(), userAgent); builder.addHeader(CONTENT_TYPE_KEY.name(), GrpcUtil.CONTENT_TYPE_GRPC); builder.addHeader("te", GrpcUtil.TE_TRAILERS); @@ -430,7 +437,7 @@ public void onResponseHeadersReceived(BidirectionalStream stream, UrlResponseInf Log.v(LOG_TAG, "BidirectionalStream.read"); } reportHeaders(info.getAllHeadersAsList(), false); - stream.read(ByteBuffer.allocateDirect(READ_BUFFER_CAPACITY)); + stream.read(ByteBuffer.allocateDirect(readBufferSize)); } @Override diff --git a/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java b/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java index b0b18620d0c..99eb88737aa 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java @@ -19,6 +19,7 @@ import com.google.common.base.Preconditions; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; @@ -33,6 +34,7 @@ import io.grpc.internal.ConnectionClientTransport; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; import java.net.InetSocketAddress; @@ -42,7 +44,6 @@ import java.util.Set; import java.util.concurrent.Executor; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * A cronet-based {@link ConnectionClientTransport} implementation. 
@@ -229,7 +230,7 @@ private void startGoAway(Status status) { startedGoAway = true; } - listener.transportShutdown(status); + listener.transportShutdown(status, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); synchronized (lock) { goAway = true; diff --git a/cronet/src/main/java/io/grpc/cronet/InternalCronetCallOptions.java b/cronet/src/main/java/io/grpc/cronet/InternalCronetCallOptions.java index e7c4144e63a..9261a0a8f4b 100644 --- a/cronet/src/main/java/io/grpc/cronet/InternalCronetCallOptions.java +++ b/cronet/src/main/java/io/grpc/cronet/InternalCronetCallOptions.java @@ -36,6 +36,18 @@ public static CallOptions withAnnotation(CallOptions callOptions, Object annotat return CronetClientStream.withAnnotation(callOptions, annotation); } + public static CallOptions withReadBufferSize(CallOptions callOptions, int size) { + return callOptions.withOption(CronetClientStream.CRONET_READ_BUFFER_SIZE_KEY, size); + } + + /** + * Returns Cronet read buffer size for gRPC included in the given {@code callOptions}. Read + * buffer can be customized via {@link #withReadBufferSize(CallOptions, int)}. + */ + public static int getReadBufferSize(CallOptions callOptions) { + return callOptions.getOption(CronetClientStream.CRONET_READ_BUFFER_SIZE_KEY); + } + /** * Returns Cronet annotations for gRPC included in the given {@code callOptions}. Annotations * are attached via {@link #withAnnotation(CallOptions, Object)}. 
diff --git a/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java b/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java index 41f48bc03bb..be437b3c80b 100644 --- a/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java +++ b/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java @@ -16,7 +16,9 @@ package io.grpc.cronet; +import static io.grpc.cronet.CronetClientStream.CRONET_READ_BUFFER_SIZE_KEY; import static io.grpc.internal.GrpcUtil.TIMER_SERVICE; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; @@ -92,6 +94,41 @@ public void alwaysUsePut_defaultsToFalse() throws Exception { assertFalse(stream.idempotent); } + @Test + public void channelBuilderReadBufferSize_defaultsTo4Kb() throws Exception { + CronetChannelBuilder builder = CronetChannelBuilder.forAddress("address", 1234, mockEngine); + CronetTransportFactory transportFactory = + (CronetTransportFactory) builder.buildTransportFactory(); + CronetClientTransport transport = + (CronetClientTransport) + transportFactory.newClientTransport( + new InetSocketAddress("localhost", 443), + new ClientTransportOptions(), + channelLogger); + CronetClientStream stream = transport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); + + assertEquals(4 * 1024, stream.readBufferSize); + } + + @Test + public void channelBuilderReadBufferSize_changeReflected() throws Exception { + CronetChannelBuilder builder = CronetChannelBuilder.forAddress("address", 1234, mockEngine); + CronetTransportFactory transportFactory = + (CronetTransportFactory) builder.buildTransportFactory(); + CronetClientTransport transport = + (CronetClientTransport) + transportFactory.newClientTransport( + new InetSocketAddress("localhost", 443), + new ClientTransportOptions(), + channelLogger); + CronetClientStream stream = transport.newStream( + method, new Metadata(), + 
CallOptions.DEFAULT.withOption(CRONET_READ_BUFFER_SIZE_KEY, 32 * 1024), tracers); + + assertEquals(32 * 1024, stream.readBufferSize); + } + @Test public void scheduledExecutorService_default() { CronetChannelBuilder builder = CronetChannelBuilder.forAddress("address", 1234, mockEngine); diff --git a/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java b/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java index 03c31f93329..3a79cc0b6a8 100644 --- a/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java +++ b/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java @@ -34,6 +34,7 @@ import io.grpc.Status; import io.grpc.cronet.CronetChannelBuilder.StreamBuilderFactory; import io.grpc.internal.ClientStreamListener; +import io.grpc.internal.DisconnectError; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.ManagedClientTransport; import io.grpc.internal.TransportTracer; @@ -128,7 +129,8 @@ public void shutdownTransport() throws Exception { BidirectionalStream.Callback callback2 = callbackCaptor.getValue(); // Shut down the transport. transportShutdown should be called immediately. transport.shutdown(); - verify(clientTransportListener).transportShutdown(any(Status.class)); + verify(clientTransportListener).transportShutdown(any(Status.class), + any(DisconnectError.class)); // Have two live streams. Transport has not been terminated. verify(clientTransportListener, times(0)).transportTerminated(); diff --git a/documentation/android-binderchannel-status-codes.md b/documentation/android-binderchannel-status-codes.md index dda0220bf8a..fae4ef406af 100644 --- a/documentation/android-binderchannel-status-codes.md +++ b/documentation/android-binderchannel-status-codes.md @@ -23,51 +23,66 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 1 + 0 - Server app not installed + Server app not visible. + + bindService() returns false - bindService() returns false +

UNIMPLEMENTED

“The operation is not implemented or is not supported / enabled in this service.” + + Give up - This is an error in the client manifest. + + + + 1 -

UNIMPLEMENTED

“The operation is not implemented or is not supported / enabled in this service.” + + Safer Intents violation. - Direct the user to install/reinstall the server app. + Direct the user to install/reinstall the server app. 2 - Old version of the server app doesn’t declare the target android.app.Service in its manifest. + Server app not installed 3 - Target android.app.Service is disabled + Old version of the server app doesn’t declare the target android.app.Service in its manifest. 4 - The whole server app is disabled + Target android.app.Service is disabled 5 - Server app predates the Android M permissions model and the user must review and approve some newly requested permissions before it can run. + The whole server app is disabled 6 + Server app predates the Android M permissions model and the user must review and approve some newly requested permissions before it can run. + + + + 7 + Target android.app.Service doesn’t recognize grpc binding Intent (old version of server app?) onNullBinding() ServiceConnection callback - 7 + 8 Method not found on the io.grpc.Server (old version of server app?) @@ -75,13 +90,13 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 8 + 9 Request cardinality violation (old version of server app expects unary rather than streaming, say) - 9 + 10 Old version of the server app exposes target android.app.Service but doesn’t android:export it. @@ -90,9 +105,11 @@ Consider the table that follows as an BinderChannel-specific addendum to the “

PERMISSION_DENIED

“The caller does not have permission to execute the specified operation …” + Direct the user to update the server app in the hopes that a newer version fixes this error in its manifest. + - 10 + 11 Target android.app.Service requires an <android:permission> that client doesn’t hold. @@ -100,7 +117,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 11 + 12 Violations of the security policy for miscellaneous Android features like android:isolatedProcess, android:externalService, android:singleUser, instant apps, BIND_TREAT_LIKE_ACTIVITY, etc, @@ -108,7 +125,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 12 + 13 Calling Android UID not allowed by ServerSecurityPolicy @@ -116,13 +133,13 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 13 + 14 Server Android UID not allowed by client’s SecurityPolicy - 14 + 15 Server process crashed or killed with request in flight. @@ -144,7 +161,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 15 + 16 Server app is currently being upgraded to a new version @@ -152,13 +169,13 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 16 + 17 The whole server app or the target android.app.Service was disabled - 17 + 18 Binder transaction buffer overflow @@ -166,7 +183,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 18 + 19 Source Context for bindService() is destroyed with a request in flight @@ -178,11 +195,11 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ Give up for now.

-(Re. 18: The caller can try again later when the user opens the source Activity or restarts the source Service) +(Re. 19: The caller can try again later when the user opens the source Activity or restarts the source Service) - 19 + 20 Client application cancelled the request @@ -190,7 +207,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 19 + 21 Bug in Android itself or the way the io.grpc.binder transport uses it. @@ -208,7 +225,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 20 + 22 Flow-control protocol violation @@ -216,7 +233,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 21 + 23 Can’t parse request/response proto @@ -226,27 +243,27 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ ### Ambiguity -We say a status code is ambiguous if it maps to two error cases that reasonable clients want to handle differently. For instance, a client may have good reasons to handle error cases 9 and 10 above differently. But they can’t do so based on status code alone because those error cases map to the same one. +We say a status code is ambiguous if it maps to two error cases that reasonable clients want to handle differently. For instance, a client may have good reasons to handle error cases 10 and 11 above differently. But they can’t do so based on status code alone because those error cases map to the same one. -In contrast, for example, even though error case 18 and 19 both map to the status code (`CANCELLED`), they are not ambiguous because we see no reason that clients would want to distinguish them. In both cases, clients will simply give up on the request. +In contrast, for example, even though error case 19 and 20 both map to the status code (`CANCELLED`), they are not ambiguous because we see no reason that clients would want to distinguish them. In both cases, clients will simply give up on the request. 
#### Ambiguity of PERMISSION_DENIED and Mitigations The mapping above has only one apparently ambiguous status code: `PERMISSION_DENIED`. However, this isn’t so bad because of the following: -The use of ``s for inter-app IPC access control (error case 10) is uncommon. Instead, we recommend that server apps only allow IPC from a limited set of client apps known in advance and identified by signature. +The use of ``s for inter-app IPC access control (error case 11) is uncommon. Instead, we recommend that server apps only allow IPC from a limited set of client apps known in advance and identified by signature. -However, there may be gRPC server apps that want to use custom <android:permission>’s to let the end user decide which arbitrary other apps can make use of its gRPC services. In that case, clients should preempt error case 10 simply by [checking whether they hold the required permissions](https://developer.android.com/training/permissions/requesting) before sending a request. +However, there may be gRPC server apps that want to use custom <android:permission>’s to let the end user decide which arbitrary other apps can make use of its gRPC services. In that case, clients should preempt error case 11 simply by [checking whether they hold the required permissions](https://developer.android.com/training/permissions/requesting) before sending a request. -Server apps can avoid error case 9 by never reusing an android.app.Service as a gRPC host if it has ever been android:exported=false in some previous app version. Instead they should simply create a new android.app.Service for this purpose. +Server apps can avoid error case 10 by never reusing an android.app.Service as a gRPC host if it has ever been android:exported=false in some previous app version. Instead they should simply create a new android.app.Service for this purpose. -Only error cases 11 - 13 remain, making `PERMISSION_DENIED` unambiguous for the purpose of error handling. 
Reasonable client apps can handle it in a generic way by displaying an error message and/or proceeding with degraded functionality. +Only error cases 12 - 14 remain, making `PERMISSION_DENIED` unambiguous for the purpose of error handling. Reasonable client apps can handle it in a generic way by displaying an error message and/or proceeding with degraded functionality. #### Non-Ambiguity of UNIMPLEMENTED -The `UNIMPLEMENTED` status code corresponds to quite a few different problems with the server app: It’s either not installed, too old, or disabled in whole or in part. Despite the diversity of underlying error cases, we believe most client apps will and should handle `UNIMPLEMENTED` in the same way: by sending the user to the app store to (re)install the server app. Reinstalling might be overkill for the disabled cases but most end users don't know what it means to enable/disable an app and there’s neither enough space in a UI dialog nor enough reader attention to explain it. Reinstalling is something users likely already understand and very likely to cure problems 1-8. +The `UNIMPLEMENTED` status code corresponds to quite a few different problems with the server app: It’s either not installed, too old, misconfigured, or disabled in whole or in part. Despite the diversity of underlying error cases, we believe most client apps will and should handle `UNIMPLEMENTED` in the same way: by sending the user to the app store to (re)install the server app. Reinstalling might be overkill for the disabled cases but most end users don't know what it means to enable/disable an app and there’s neither enough space in a UI dialog nor enough reader attention to explain it. Reinstalling is something users likely already understand and likely to cure problems 0-9 (once a fixed version of the server is available). ## Detailed Discussion of Binder Failure Modes @@ -315,6 +332,8 @@ According to a review of the AOSP source code, there are in fact several cases: 1. 
The target package is not installed 2. The target package is installed but does not declare the target Service in its manifest. 3. The target package requests dangerous permissions but targets sdk <= M and therefore requires a permissions review, but the caller is not running in the foreground and so it would be inappropriate to launch the review UI. +4. The target package is not visible to the client due to [Android 11 package visibility rules](https://developer.android.com/training/package-visibility). +5. One of the new [Safer Intents](https://developer.android.com/about/versions/15/behavior-changes-15#safer-intents) rules is violated. Most commonly, the bind `Intent` specifies a `ComponentName` explicitly but doesn't match any of its <intent-filter>s. Status code mapping: **UNIMPLEMENTED** @@ -322,6 +341,7 @@ Status code mapping: **UNIMPLEMENTED** Unfortunately `UNIMPLEMENTED` doesn’t capture (3) but none of the other canonical status codes do either and we expect this case to be extremely rare. +(4) and (5) are intentionally indistinguishable from (1) by Android design so we can't handle them differently. However, as an error in its own manifest, (4) isn't something a reasonable client would handle at runtime anyway. (5) is an error in the server manifest and so, just like the other cases, the best practice for handling it is to send the user to the app store in the hope that the server can be updated with a fix. ### bindService() throws SecurityException @@ -382,4 +402,4 @@ Android’s Parcel class exposes a mechanism for marshalling certain types of `R The calling Activity or Service Context might be destroyed with a gRPC request in flight. Apps should cease operations when the Context hosting it goes away and this includes cancelling any outstanding RPCs.
-Status code mapping: **CANCELLED** \ No newline at end of file +Status code mapping: **CANCELLED** diff --git a/documentation/server-reflection-tutorial.md b/documentation/server-reflection-tutorial.md index 5fad5a22333..f452174738a 100644 --- a/documentation/server-reflection-tutorial.md +++ b/documentation/server-reflection-tutorial.md @@ -10,9 +10,9 @@ proto-based services. ## Enable Server Reflection gRPC-Java Server Reflection is implemented by -`io.grpc.protobuf.services.ProtoReflectionService` in the `grpc-services` +`io.grpc.protobuf.services.ProtoReflectionServiceV1` in the `grpc-services` package. To enable server reflection, you need to add the -`ProtoReflectionService` to your gRPC server. +`ProtoReflectionServiceV1` to your gRPC server. For example, to enable server reflection in `examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java`, we @@ -28,14 +28,14 @@ need to make the following changes: + compile "io.grpc:grpc-services:${grpcVersion}" compile "io.grpc:grpc-stub:${grpcVersion}" - testCompile "junit:junit:4.12" + testCompile "junit:junit:4.13.2" --- a/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java +++ b/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java @@ -33,6 +33,7 @@ package io.grpc.examples.helloworld; import io.grpc.Server; import io.grpc.ServerBuilder; -+import io.grpc.protobuf.services.ProtoReflectionService; ++import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.stub.StreamObserver; import java.io.IOException; import java.util.logging.Logger; @@ -43,7 +43,7 @@ need to make the following changes: int port = 50051; server = ServerBuilder.forPort(port) .addService(new GreeterImpl()) -+ .addService(ProtoReflectionService.newInstance()) ++ .addService(ProtoReflectionServiceV1.newInstance()) .build() .start(); logger.info("Server started, listening on " + port); diff --git a/examples/.bazelrc b/examples/.bazelrc index 554440cfe3d..53485cb9743 100644 --- 
a/examples/.bazelrc +++ b/examples/.bazelrc @@ -1 +1 @@ -build --cxxopt=-std=c++14 --host_cxxopt=-std=c++14 +build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 diff --git a/examples/BUILD.bazel b/examples/BUILD.bazel index 3a0936780a0..e3ef8c5ac5d 100644 --- a/examples/BUILD.bazel +++ b/examples/BUILD.bazel @@ -1,5 +1,8 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@io_grpc_grpc_java//:java_grpc_library.bzl", "java_grpc_library") +load("@rules_java//java:java_binary.bzl", "java_binary") +load("@rules_java//java:java_library.bzl", "java_library") proto_library( name = "helloworld_proto", diff --git a/examples/MODULE.bazel b/examples/MODULE.bazel index 60bed40f349..2e90a63c219 100644 --- a/examples/MODULE.bazel +++ b/examples/MODULE.bazel @@ -1,22 +1,25 @@ -bazel_dep(name = "googleapis", repo_name = "com_google_googleapis", version = "0.0.0-20240326-1c8d509c5") -bazel_dep(name = "grpc-java", repo_name = "io_grpc_grpc_java", version = "1.66.0-SNAPSHOT") # CURRENT_GRPC_VERSION -bazel_dep(name = "grpc-proto", repo_name = "io_grpc_grpc_proto", version = "0.0.0-20240627-ec30f58") -bazel_dep(name = "protobuf", repo_name = "com_google_protobuf", version = "23.1") +bazel_dep(name = "grpc-java", version = "1.81.0-SNAPSHOT", repo_name = "io_grpc_grpc_java") # CURRENT_GRPC_VERSION +bazel_dep(name = "rules_java", version = "9.3.0") +bazel_dep(name = "grpc-proto", version = "0.0.0-20240627-ec30f58", repo_name = "io_grpc_grpc_proto") +bazel_dep(name = "protobuf", version = "33.1", repo_name = "com_google_protobuf") bazel_dep(name = "rules_jvm_external", version = "6.0") -bazel_dep(name = "rules_proto", version = "5.3.0-21.7") -# Do not use this override in your own MODULE.bazel. Use a version from BCR +# Do not use this override in your own MODULE.bazel. It is unnecessary when +# using a version from BCR. 
Be aware the gRPC Java team does not update the +# BCR for new releases, so you may need to create a PR for the BCR to add the +# version. To not use the BCR, you could use: +# +# git_override( +# module_name = "grpc-java", +# remote = "https://github.com/grpc/grpc-java.git", +# tag = "v", +# ) local_path_override( module_name = "grpc-java", path = "..", ) -switched_rules = use_extension("@com_google_googleapis//:extensions.bzl", "switched_rules") - -switched_rules.use_languages(java = True) - maven = use_extension("@rules_jvm_external//:extensions.bzl", "maven") - use_repo(maven, "maven") maven.install( diff --git a/examples/README.md b/examples/README.md index b51d560d7bb..91fde2c045c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -27,114 +27,32 @@ before trying out the examples. - [Json serialization](src/main/java/io/grpc/examples/advanced) --

- Hedging - - The [hedging example](src/main/java/io/grpc/examples/hedging) demonstrates that enabling hedging - can reduce tail latency. (Users should note that enabling hedging may introduce other overhead; - and in some scenarios, such as when some server resource gets exhausted for a period of time and - almost every RPC during that time has high latency or fails, hedging may make things worse. - Setting a throttle in the service config is recommended to protect the server from too many - inappropriate retry or hedging requests.) - - The server and the client in the example are basically the same as those in the - [hello world](src/main/java/io/grpc/examples/helloworld) example, except that the server mimics a - long tail of latency, and the client sends 2000 requests and can turn on and off hedging. - - To mimic the latency, the server randomly delays the RPC handling by 2 seconds at 10% chance, 5 - seconds at 5% chance, and 10 seconds at 1% chance. - - When running the client enabling the following hedging policy - - ```json - "hedgingPolicy": { - "maxAttempts": 3, - "hedgingDelay": "1s" - } - ``` - Then the latency summary in the client log is like the following - - ```text - Total RPCs sent: 2,000. Total RPCs failed: 0 - [Hedging enabled] - ======================== - 50% latency: 0ms - 90% latency: 6ms - 95% latency: 1,003ms - 99% latency: 2,002ms - 99.9% latency: 2,011ms - Max latency: 5,272ms - ======================== - ``` - - See [the section below](#to-build-the-examples) for how to build and run the example. The - executables for the server and the client are `hedging-hello-world-server` and - `hedging-hello-world-client`. - - To disable hedging, set environment variable `DISABLE_HEDGING_IN_HEDGING_EXAMPLE=true` before - running the client. That produces a latency summary in the client log like the following - - ```text - Total RPCs sent: 2,000. 
Total RPCs failed: 0 - [Hedging disabled] - ======================== - 50% latency: 0ms - 90% latency: 2,002ms - 95% latency: 5,002ms - 99% latency: 10,004ms - 99.9% latency: 10,007ms - Max latency: 10,007ms - ======================== - ``` - -
- --
- Retrying - - The [retrying example](src/main/java/io/grpc/examples/retrying) provides a HelloWorld gRPC client & - server which demos the effect of client retry policy configured on the [ManagedChannel]( - ../api/src/main/java/io/grpc/ManagedChannel.java) via [gRPC ServiceConfig]( - https://github.com/grpc/grpc/blob/master/doc/service_config.md). Retry policy implementation & - configuration details are outlined in the [proposal](https://github.com/grpc/proposal/blob/master/A6-client-retries.md). - - This retrying example is very similar to the [hedging example](src/main/java/io/grpc/examples/hedging) in its setup. - The [RetryingHelloWorldServer](src/main/java/io/grpc/examples/retrying/RetryingHelloWorldServer.java) responds with - a status UNAVAILABLE error response to a specified percentage of requests to simulate server resource exhaustion and - general flakiness. The [RetryingHelloWorldClient](src/main/java/io/grpc/examples/retrying/RetryingHelloWorldClient.java) makes - a number of sequential requests to the server, several of which will be retried depending on the configured policy in - [retrying_service_config.json](src/main/resources/io/grpc/examples/retrying/retrying_service_config.json). Although - the requests are blocking unary calls for simplicity, these could easily be changed to future unary calls in order to - test the result of request concurrency with retry policy enabled. - - One can experiment with the [RetryingHelloWorldServer](src/main/java/io/grpc/examples/retrying/RetryingHelloWorldServer.java) - failure conditions to simulate server throttling, as well as alter policy values in the [retrying_service_config.json]( - src/main/resources/io/grpc/examples/retrying/retrying_service_config.json) to see their effects. To disable retrying - entirely, set environment variable `DISABLE_RETRYING_IN_RETRYING_EXAMPLE=true` before running the client. - Disabling the retry policy should produce many more failed gRPC calls as seen in the output log. 
- - See [the section below](#to-build-the-examples) for how to build and run the example. The - executables for the server and the client are `retrying-hello-world-server` and - `retrying-hello-world-client`. - -
- --
- Health Service - - The [health service example](src/main/java/io/grpc/examples/healthservice) - provides a HelloWorld gRPC server that doesn't like short names along with a - health service. It also provides a client application which makes HelloWorld - calls and checks the health status. - - The client application also shows how the round robin load balancer can - utilize the health status to avoid making calls to a service that is - not actively serving. -
+- [Hedging example](src/main/java/io/grpc/examples/hedging) +- [Retrying example](src/main/java/io/grpc/examples/retrying) + +- [Health Service example](src/main/java/io/grpc/examples/healthservice) - [Keep Alive](src/main/java/io/grpc/examples/keepalive) +- [Cancellation](src/main/java/io/grpc/examples/cancellation) + +- [Custom Load Balance](src/main/java/io/grpc/examples/customloadbalance) + +- [Deadline](src/main/java/io/grpc/examples/deadline) + +- [Error Details](src/main/java/io/grpc/examples/errordetails) + +- [GRPC Proxy](src/main/java/io/grpc/examples/grpcproxy) + +- [Load Balance](src/main/java/io/grpc/examples/loadbalance) + +- [Multiplex](src/main/java/io/grpc/examples/multiplex) + +- [Name Resolve](src/main/java/io/grpc/examples/nameresolve) + +- [Pre-Serialized Messages](src/main/java/io/grpc/examples/preserialized) + ### To build the examples 1. **[Install gRPC Java library SNAPSHOT locally, including code generation plugin](../COMPILING.md) (Only need this step for non-released versions, e.g. master HEAD).** @@ -235,9 +153,9 @@ Example bugs not caught by mocked stub tests include: For testing a gRPC client, create the client with a real stub using an -[InProcessChannel](../core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java), +[InProcessChannel](../inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java), and test it against an -[InProcessServer](../core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java) +[InProcessServer](../inprocess/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java) with a mock/fake service implementation. 
For testing a gRPC server, create the server as an InProcessServer, diff --git a/examples/WORKSPACE b/examples/WORKSPACE index 170e06a90c7..1387cc4cf12 100644 --- a/examples/WORKSPACE +++ b/examples/WORKSPACE @@ -14,38 +14,57 @@ local_repository( path = "..", ) -http_archive( - name = "rules_jvm_external", - sha256 = "d31e369b854322ca5098ea12c69d7175ded971435e55c18dd9dd5f29cc5249ac", - strip_prefix = "rules_jvm_external-5.3", - url = "https://github.com/bazelbuild/rules_jvm_external/releases/download/5.3/rules_jvm_external-5.3.tar.gz", -) - -load("@rules_jvm_external//:defs.bzl", "maven_install") -load("@io_grpc_grpc_java//:repositories.bzl", "IO_GRPC_GRPC_JAVA_ARTIFACTS") -load("@io_grpc_grpc_java//:repositories.bzl", "IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS") -load("@io_grpc_grpc_java//:repositories.bzl", "grpc_java_repositories") +load("@io_grpc_grpc_java//:repositories.bzl", "IO_GRPC_GRPC_JAVA_ARTIFACTS", "IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS", "grpc_java_repositories") grpc_java_repositories() +http_archive( + name = "rules_java", + sha256 = "47632cc506c858011853073449801d648e10483d4b50e080ec2549a4b2398960", + urls = [ + "https://github.com/bazelbuild/rules_java/releases/download/8.15.2/rules_java-8.15.2.tar.gz", + ], +) + # Protobuf now requires C++14 or higher, which requires Bazel configuration # outside the WORKSPACE. See .bazelrc in this directory. 
-load("@com_google_protobuf//:protobuf_deps.bzl", "PROTOBUF_MAVEN_ARTIFACTS") -load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") +load("@com_google_protobuf//:protobuf_deps.bzl", "PROTOBUF_MAVEN_ARTIFACTS", "protobuf_deps") protobuf_deps() -load("@envoy_api//bazel:repositories.bzl", "api_dependencies") +load("@rules_java//java:rules_java_deps.bzl", "compatibility_proxy_repo", "rules_java_dependencies") + +rules_java_dependencies() + +load("@bazel_features//:deps.bzl", "bazel_features_deps") -api_dependencies() +bazel_features_deps() + +compatibility_proxy_repo() + +load("@bazel_jar_jar//:jar_jar.bzl", "jar_jar_repositories") + +jar_jar_repositories() + +load("@rules_python//python:repositories.bzl", "py_repositories") + +py_repositories() load("@com_google_googleapis//:repository_rules.bzl", "switched_rules_by_language") switched_rules_by_language( name = "com_google_googleapis_imports", - java = True, ) +http_archive( + name = "rules_jvm_external", + sha256 = "d31e369b854322ca5098ea12c69d7175ded971435e55c18dd9dd5f29cc5249ac", + strip_prefix = "rules_jvm_external-5.3", + url = "https://github.com/bazelbuild/rules_jvm_external/releases/download/5.3/rules_jvm_external-5.3.tar.gz", +) + +load("@rules_jvm_external//:defs.bzl", "maven_install") + maven_install( artifacts = [ "com.google.api.grpc:grpc-google-cloud-pubsub-v1:0.1.24", diff --git a/examples/android/clientcache/app/build.gradle b/examples/android/clientcache/app/build.gradle index 6b5b966e7f6..67110a78c43 100644 --- a/examples/android/clientcache/app/build.gradle +++ b/examples/android/clientcache/app/build.gradle @@ -10,9 +10,8 @@ android { defaultConfig { applicationId "io.grpc.clientcacheexample" - minSdkVersion 21 + minSdkVersion 23 targetSdkVersion 33 - multiDexEnabled true versionCode 1 versionName "1.0" testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" @@ -34,7 +33,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - 
grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -54,12 +53,11 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'org.apache.tomcat:annotations-api:6.0.53' + implementation 'io.grpc:grpc-okhttp:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION testImplementation 'junit:junit:4.13.2' - testImplementation 'com.google.truth:truth:1.1.5' - testImplementation 'io.grpc:grpc-testing:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + testImplementation 'com.google.truth:truth:1.4.5' + testImplementation 'io.grpc:grpc-testing:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/clientcache/build.gradle b/examples/android/clientcache/build.gradle index 67d25905bbc..6db6a9bced1 100644 --- a/examples/android/clientcache/build.gradle +++ b/examples/android/clientcache/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:7.4.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.4" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.5" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/clientcache/settings.gradle b/examples/android/clientcache/settings.gradle index e7b4def49cb..6208d70e838 100644 --- a/examples/android/clientcache/settings.gradle +++ 
b/examples/android/clientcache/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + include ':app' diff --git a/examples/android/helloworld/app/build.gradle b/examples/android/helloworld/app/build.gradle index 4edbcb14612..d20bd03d1fc 100644 --- a/examples/android/helloworld/app/build.gradle +++ b/examples/android/helloworld/app/build.gradle @@ -10,7 +10,7 @@ android { defaultConfig { applicationId "io.grpc.helloworldexample" - minSdkVersion 21 + minSdkVersion 23 targetSdkVersion 33 versionCode 1 versionName "1.0" @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,7 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. 
- implementation 'io.grpc:grpc-okhttp:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'org.apache.tomcat:annotations-api:6.0.53' + implementation 'io.grpc:grpc-okhttp:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/helloworld/build.gradle b/examples/android/helloworld/build.gradle index 67d25905bbc..6db6a9bced1 100644 --- a/examples/android/helloworld/build.gradle +++ b/examples/android/helloworld/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:7.4.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.4" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.5" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/helloworld/settings.gradle b/examples/android/helloworld/settings.gradle index e7b4def49cb..6208d70e838 100644 --- a/examples/android/helloworld/settings.gradle +++ b/examples/android/helloworld/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + include ':app' diff --git a/examples/android/routeguide/app/build.gradle b/examples/android/routeguide/app/build.gradle index 4a08f40e4ee..377cb417100 100644 --- a/examples/android/routeguide/app/build.gradle +++ b/examples/android/routeguide/app/build.gradle @@ -10,7 +10,7 @@ android { defaultConfig { applicationId "io.grpc.routeguideexample" - minSdkVersion 21 + minSdkVersion 23 targetSdkVersion 33 versionCode 1 versionName "1.0" @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,7 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. 
- implementation 'io.grpc:grpc-okhttp:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'org.apache.tomcat:annotations-api:6.0.53' + implementation 'io.grpc:grpc-okhttp:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/routeguide/build.gradle b/examples/android/routeguide/build.gradle index fd058a5d68e..8fc1d293228 100644 --- a/examples/android/routeguide/build.gradle +++ b/examples/android/routeguide/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:7.4.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.4" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.5" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/routeguide/settings.gradle b/examples/android/routeguide/settings.gradle index e7b4def49cb..6208d70e838 100644 --- a/examples/android/routeguide/settings.gradle +++ b/examples/android/routeguide/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + include ':app' diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index 9f41994e3c2..b752bc4ffd3 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -33,7 +33,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -53,8 +53,7 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'org.apache.tomcat:annotations-api:6.0.53' + implementation 'io.grpc:grpc-okhttp:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/strictmode/build.gradle b/examples/android/strictmode/build.gradle index 67d25905bbc..6db6a9bced1 100644 --- a/examples/android/strictmode/build.gradle +++ b/examples/android/strictmode/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:7.4.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.4" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.5" // NOTE: Do not place your application dependencies here; they belong // in 
the individual module build.gradle files diff --git a/examples/android/strictmode/settings.gradle b/examples/android/strictmode/settings.gradle index e7b4def49cb..6208d70e838 100644 --- a/examples/android/strictmode/settings.gradle +++ b/examples/android/strictmode/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + include ':app' diff --git a/examples/build.gradle b/examples/build.gradle index c10b4eef46a..62e50d38861 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -1,14 +1,12 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -23,15 +21,14 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. 
-def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' def protocVersion = protobufVersion dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-services:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" // examples/advanced need this for JsonFormat implementation "com.google.protobuf:protobuf-java-util:${protobufVersion}" @@ -54,15 +51,14 @@ protobuf { } } -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} +// gRPC uses java.util.ServiceLoader, which reads class names from +// META-INF/services in jars. If you package your application as a "fat" jar +// that includes dependencies, you need to make sure the packaging tool +// concatenates duplicate files in META-INF/services. +// +// For the Shadow Gradle Plugin, use call mergeServiceFiles() within the +// shadowJar task. 
+// https://gradleup.com/shadow/configuration/merging/#merging-service-descriptor-files startScripts.enabled = false @@ -109,6 +105,7 @@ createStartScripts('io.grpc.examples.keepalive.KeepAliveClient') createStartScripts('io.grpc.examples.keepalive.KeepAliveServer') createStartScripts('io.grpc.examples.loadbalance.LoadBalanceClient') createStartScripts('io.grpc.examples.loadbalance.LoadBalanceServer') +createStartScripts('io.grpc.examples.manualflowcontrol.BidiBlockingClient') createStartScripts('io.grpc.examples.manualflowcontrol.ManualFlowControlClient') createStartScripts('io.grpc.examples.manualflowcontrol.ManualFlowControlServer') createStartScripts('io.grpc.examples.multiplex.MultiplexingServer') diff --git a/examples/example-alts/BUILD.bazel b/examples/example-alts/BUILD.bazel index 0404dcccf81..4d66accfc19 100644 --- a/examples/example-alts/BUILD.bazel +++ b/examples/example-alts/BUILD.bazel @@ -1,5 +1,8 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@io_grpc_grpc_java//:java_grpc_library.bzl", "java_grpc_library") +load("@rules_java//java:java_binary.bzl", "java_binary") +load("@rules_java//java:java_library.bzl", "java_library") proto_library( name = "helloworld_proto", diff --git a/examples/example-alts/example-alts/README.md b/examples/example-alts/README.md similarity index 100% rename from examples/example-alts/example-alts/README.md rename to examples/example-alts/README.md diff --git a/examples/example-alts/build.gradle b/examples/example-alts/build.gradle index 0d7d959de93..47268ab6510 100644 --- a/examples/example-alts/build.gradle +++ b/examples/example-alts/build.gradle @@ -1,15 +1,12 @@ plugins { // Provide convenience executables for trying out the examples. 
id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } mavenCentral() mavenLocal() } @@ -24,13 +21,12 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { // grpc-alts transitively depends on grpc-netty-shaded, grpc-protobuf, and grpc-stub implementation "io.grpc:grpc-alts:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" } protobuf { @@ -43,16 +39,6 @@ protobuf { } } -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} - startScripts.enabled = false diff --git a/examples/example-alts/settings.gradle b/examples/example-alts/settings.gradle index 273558dd9cf..6bd0f0cdc2d 100644 --- a/examples/example-alts/settings.gradle +++ b/examples/example-alts/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-debug/build.gradle b/examples/example-debug/build.gradle index 5565747cb19..940543a3681 100644 --- a/examples/example-debug/build.gradle +++ b/examples/example-debug/build.gradle @@ -2,15 +2,13 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. id 'java' - id "com.google.protobuf" version "0.9.4" + id "com.google.protobuf" version "0.9.5" // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -25,14 +23,13 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. 
-def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-services:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" testImplementation 'junit:junit:4.13.2' diff --git a/examples/example-debug/pom.xml b/examples/example-debug/pom.xml index 064d989c04c..10734935ee6 100644 --- a/examples/example-debug/pom.xml +++ b/examples/example-debug/pom.xml @@ -6,14 +6,14 @@ jar - 1.68.0-SNAPSHOT + 1.81.0-SNAPSHOT example-debug https://github.com/grpc/grpc-java UTF-8 - 1.68.0-SNAPSHOT - 3.25.3 + 1.81.0-SNAPSHOT + 3.25.8 1.8 1.8 @@ -44,12 +44,6 @@ io.grpc grpc-stub - - org.apache.tomcat - annotations-api - 6.0.53 - provided - io.grpc grpc-netty-shaded diff --git a/examples/example-debug/settings.gradle b/examples/example-debug/settings.gradle index 3700c983b6c..48c08629ca9 100644 --- a/examples/example-debug/settings.gradle +++ b/examples/example-debug/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'example-debug' diff --git a/examples/example-debug/src/main/java/io/grpc/examples/debug/HelloWorldDebuggableClient.java b/examples/example-debug/src/main/java/io/grpc/examples/debug/HelloWorldDebuggableClient.java index 61391b60415..ef1340cf259 100644 --- a/examples/example-debug/src/main/java/io/grpc/examples/debug/HelloWorldDebuggableClient.java +++ b/examples/example-debug/src/main/java/io/grpc/examples/debug/HelloWorldDebuggableClient.java @@ -27,7 +27,7 @@ import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; import io.grpc.examples.helloworld.HelloRequest; -import io.grpc.protobuf.services.ProtoReflectionService; +import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.services.AdminInterface; import java.util.concurrent.TimeUnit; import java.util.logging.Level; diff --git a/examples/example-debug/src/main/java/io/grpc/examples/debug/HostnameDebuggableServer.java b/examples/example-debug/src/main/java/io/grpc/examples/debug/HostnameDebuggableServer.java index 89ffc39b599..5525ba91d9c 100644 --- a/examples/example-debug/src/main/java/io/grpc/examples/debug/HostnameDebuggableServer.java +++ b/examples/example-debug/src/main/java/io/grpc/examples/debug/HostnameDebuggableServer.java @@ -21,7 +21,7 @@ import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; -import io.grpc.protobuf.services.ProtoReflectionService; +import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.services.AdminInterface; import io.grpc.services.HealthStatusManager; import java.io.IOException; diff --git a/examples/example-dualstack/README.md b/examples/example-dualstack/README.md index 6c191661d1b..5a26886e259 100644 --- 
a/examples/example-dualstack/README.md +++ b/examples/example-dualstack/README.md @@ -2,7 +2,7 @@ The dualstack example uses a custom name resolver that provides both IPv4 and IPv6 localhost endpoints for each of 3 server instances. The client will first use the default name resolver and -load balancers which will only connect tot he first server. It will then use the +load balancers which will only connect to the first server. It will then use the custom name resolver with round robin to connect to each of the servers in turn. The 3 instances of the server will bind respectively to: both IPv4 and IPv6, IPv4 only, and IPv6 only. diff --git a/examples/example-dualstack/build.gradle b/examples/example-dualstack/build.gradle index 554b5f758d9..f2947c641cf 100644 --- a/examples/example-dualstack/build.gradle +++ b/examples/example-dualstack/build.gradle @@ -2,15 +2,13 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. id 'java' - id "com.google.protobuf" version "0.9.4" + id "com.google.protobuf" version "0.9.5" // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -25,15 +23,14 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. 
-def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-netty:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-services:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" } protobuf { diff --git a/examples/example-dualstack/pom.xml b/examples/example-dualstack/pom.xml index dfd650cdfa4..f5e720a9128 100644 --- a/examples/example-dualstack/pom.xml +++ b/examples/example-dualstack/pom.xml @@ -6,14 +6,14 @@ jar - 1.68.0-SNAPSHOT + 1.81.0-SNAPSHOT example-dualstack https://github.com/grpc/grpc-java UTF-8 - 1.68.0-SNAPSHOT - 3.25.3 + 1.81.0-SNAPSHOT + 3.25.8 1.8 1.8 @@ -48,12 +48,6 @@ io.grpc grpc-netty - - org.apache.tomcat - annotations-api - 6.0.53 - provided - io.grpc grpc-netty-shaded diff --git a/examples/example-dualstack/settings.gradle b/examples/example-dualstack/settings.gradle index 0aae8f7304e..160d5134334 100644 --- a/examples/example-dualstack/settings.gradle +++ b/examples/example-dualstack/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-gauth/BUILD.bazel b/examples/example-gauth/BUILD.bazel index edc4a291e27..033c51f8856 100644 --- a/examples/example-gauth/BUILD.bazel +++ b/examples/example-gauth/BUILD.bazel @@ -1,4 +1,5 @@ -load("@io_grpc_grpc_java//:java_grpc_library.bzl", "java_grpc_library") +load("@rules_java//java:java_binary.bzl", "java_binary") +load("@rules_java//java:java_library.bzl", "java_library") java_library( name = "example-gauth", diff --git a/examples/example-gauth/README.md b/examples/example-gauth/README.md index 622c14cb57b..b49d346a9be 100644 --- a/examples/example-gauth/README.md +++ b/examples/example-gauth/README.md @@ -43,13 +43,13 @@ gcloud pubsub topics create Topic1 5. You will now need to set up [authentication](https://cloud.google.com/docs/authentication/) and a [service account](https://cloud.google.com/docs/authentication/#service_accounts) in order to access Pub/Sub via gRPC APIs as described [here](https://cloud.google.com/iam/docs/creating-managing-service-accounts). -Assign the [role](https://cloud.google.com/iam/docs/granting-roles-to-service-accounts) `Project -> Owner` +(**Note:** This step is unnecessary on Google platforms (Google App Engine / Google Cloud Shell / Google Compute Engine) as it will +automatically use the in-built Google credentials). Assign the [role](https://cloud.google.com/iam/docs/granting-roles-to-service-accounts) `Project -> Owner` and for Key type select JSON. Once you click `Create`, a JSON file containing your key is downloaded to your computer. Note down the path of this file or copy this file to the computer and file system where you will be running the example application as described later. Assume this JSON file is available at -`/path/to/JSON/file`. 
You can also use the `gcloud` shell commands to -[create the service account](https://cloud.google.com/iam/docs/creating-managing-service-accounts#iam-service-accounts-create-gcloud) -and [the JSON file](https://cloud.google.com/iam/docs/creating-managing-service-account-keys#iam-service-account-keys-create-gcloud). +`/path/to/JSON/file` Set the value of the environment variable GOOGLE_APPLICATION_CREDENTIALS to this file path. You can also use the `gcloud` shell commands to +[create the service account](https://cloud.google.com/iam/docs/creating-managing-service-accounts#iam-service-accounts-create-gcloud). #### To build the examples @@ -62,19 +62,18 @@ $ ../gradlew installDist #### How to run the example: -`google-auth-client` requires two command line arguments for the location of the JSON file and the project ID: +`google-auth-client` requires one command line argument for the project ID: ```text -USAGE: GoogleAuthClient +USAGE: GoogleAuthClient ``` -The first argument is the location of the JSON file you created in step 5 above. -The second argument is the project ID in the form "projects/xyz123" where "xyz123" is +The first argument is the project ID in the form "projects/xyz123" where "xyz123" is the project ID of the project you created (or used) in step 2 above. ```bash # Run the client -./build/install/example-gauth/bin/google-auth-client /path/to/JSON/file projects/xyz123 +./build/install/example-gauth/bin/google-auth-client projects/xyz123 ``` That's it! The client will show the list of Pub/Sub topics for the project as follows: @@ -93,7 +92,7 @@ the project ID of the project you created (or used) in step 2 above. 
``` $ mvn verify $ # Run the client - $ mvn exec:java -Dexec.mainClass=io.grpc.examples.googleAuth.GoogleAuthClient -Dexec.args="/path/to/JSON/file projects/xyz123" + $ mvn exec:java -Dexec.mainClass=io.grpc.examples.googleAuth.GoogleAuthClient -Dexec.args="projects/xyz123" ``` ## Bazel @@ -101,5 +100,5 @@ the project ID of the project you created (or used) in step 2 above. ``` $ bazel build :google-auth-client $ # Run the client - $ ../bazel-bin/google-auth-client /path/to/JSON/file projects/xyz123 + $ ../bazel-bin/google-auth-client projects/xyz123 ``` \ No newline at end of file diff --git a/examples/example-gauth/build.gradle b/examples/example-gauth/build.gradle index 47e812fde15..489197e5f20 100644 --- a/examples/example-gauth/build.gradle +++ b/examples/example-gauth/build.gradle @@ -1,15 +1,12 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } mavenCentral() mavenLocal() } @@ -24,8 +21,8 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. 
-def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' def protocVersion = protobufVersion @@ -33,8 +30,7 @@ dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-auth:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" - implementation "com.google.auth:google-auth-library-oauth2-http:1.23.0" + implementation "com.google.auth:google-auth-library-oauth2-http:1.42.1" implementation "com.google.api.grpc:grpc-google-cloud-pubsub-v1:0.1.24" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" } @@ -49,16 +45,6 @@ protobuf { } } -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} - startScripts.enabled = false task googleAuthClient(type: CreateStartScripts) { diff --git a/examples/example-gauth/pom.xml b/examples/example-gauth/pom.xml index d2cba1a7959..9fb854629b4 100644 --- a/examples/example-gauth/pom.xml +++ b/examples/example-gauth/pom.xml @@ -6,14 +6,14 @@ jar - 1.68.0-SNAPSHOT + 1.81.0-SNAPSHOT example-gauth https://github.com/grpc/grpc-java UTF-8 - 1.68.0-SNAPSHOT - 3.25.3 + 1.81.0-SNAPSHOT + 3.25.8 1.8 1.8 @@ -28,6 +28,11 @@ pom import + + com.google.code.gson + gson + 2.13.2 + @@ -49,12 +54,6 @@ io.grpc grpc-auth - - org.apache.tomcat - annotations-api - 6.0.53 - provided - io.grpc grpc-testing @@ -63,7 +62,7 @@ com.google.auth google-auth-library-oauth2-http - 1.23.0 + 1.40.0 com.google.api.grpc diff --git a/examples/example-gauth/settings.gradle b/examples/example-gauth/settings.gradle index 273558dd9cf..6bd0f0cdc2d 100644 --- a/examples/example-gauth/settings.gradle +++ b/examples/example-gauth/settings.gradle @@ -1,8 +1,19 @@ pluginManagement 
{ - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-gauth/src/main/java/io/grpc/examples/googleAuth/GoogleAuthClient.java b/examples/example-gauth/src/main/java/io/grpc/examples/googleAuth/GoogleAuthClient.java index 4d3dd044376..eb0d9feedfc 100644 --- a/examples/example-gauth/src/main/java/io/grpc/examples/googleAuth/GoogleAuthClient.java +++ b/examples/example-gauth/src/main/java/io/grpc/examples/googleAuth/GoogleAuthClient.java @@ -33,7 +33,7 @@ /** * Example to illustrate use of Google credentials as described in - * @see Google Auth Example README + * @see Google Auth Example README * * Also @see Google Cloud Pubsub via gRPC */ @@ -52,7 +52,7 @@ public class GoogleAuthClient { * * @param host host to connect to - typically "pubsub.googleapis.com" * @param port port to connect to - typically 443 - the TLS port - * @param callCredentials the Google call credentials created from a JSON file + * @param callCredentials the Google call credentials */ public GoogleAuthClient(String host, int port, CallCredentials callCredentials) { // Google API invocation requires a secure channel. Channels are secure by default (SSL/TLS) @@ -63,7 +63,7 @@ public GoogleAuthClient(String host, int port, CallCredentials callCredentials) * Construct our gRPC client that connects to the pubsub server using an existing channel. 
* * @param channel channel that has been built already - * @param callCredentials the Google call credentials created from a JSON file + * @param callCredentials the Google call credentials */ GoogleAuthClient(ManagedChannel channel, CallCredentials callCredentials) { this.channel = channel; @@ -101,32 +101,30 @@ public void getTopics(String projectID) { /** * The app requires 2 arguments as described in - * @see Google Auth Example README + * @see Google Auth Example README * - * arg0 = location of the JSON file for the service account you created in the GCP console - * arg1 = project name in the form "projects/balmy-cirrus-225307" where "balmy-cirrus-225307" is + * arg0 = project name in the form "projects/balmy-cirrus-225307" where "balmy-cirrus-225307" is * the project ID for the project you created. * + * On non-Google platforms, the GOOGLE_APPLICATION_CREDENTIALS env variable should be set to the + * location of the JSON file for the service account you created in the GCP console. 
*/ public static void main(String[] args) throws Exception { - if (args.length < 2) { - logger.severe("Usage: please pass 2 arguments:\n" + - "arg0 = location of the JSON file for the service account you created in the GCP console\n" + - "arg1 = project name in the form \"projects/xyz\" where \"xyz\" is the project ID of the project you created.\n"); + if (args.length < 1) { + logger.severe("Usage: please pass 1 argument:\n" + + "arg0 = project name in the form \"projects/xyz\" where \"xyz\" is the project ID of the project you created.\n"); System.exit(1); } - GoogleCredentials credentials = GoogleCredentials.fromStream(new FileInputStream(args[0])); + GoogleCredentials credentials = GoogleCredentials.getApplicationDefault(); // We need to create appropriate scope as per https://cloud.google.com/storage/docs/authentication#oauth-scopes credentials = credentials.createScoped(Arrays.asList("https://www.googleapis.com/auth/cloud-platform")); - // credentials must be refreshed before the access token is available - credentials.refreshAccessToken(); GoogleAuthClient client = new GoogleAuthClient("pubsub.googleapis.com", 443, MoreCallCredentials.from(credentials)); try { - client.getTopics(args[1]); + client.getTopics(args[0]); } finally { client.shutdown(); } diff --git a/examples/example-gcp-csm-observability/build.gradle b/examples/example-gcp-csm-observability/build.gradle index a392018ba25..63c6d20125d 100644 --- a/examples/example-gcp-csm-observability/build.gradle +++ b/examples/example-gcp-csm-observability/build.gradle @@ -1,16 +1,13 @@ plugins { // Provide convenience executables for trying out the examples. 
id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' id 'java' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } mavenCentral() mavenLocal() } @@ -25,19 +22,19 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.3' -def openTelemetryVersion = '1.40.0' -def openTelemetryPrometheusVersion = '1.40.0-alpha' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' +def openTelemetryVersion = '1.56.0' +def openTelemetryPrometheusVersion = '1.56.0-alpha' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-gcp-csm-observability:${grpcVersion}" + implementation "io.grpc:grpc-xds:${grpcVersion}" implementation "io.opentelemetry:opentelemetry-sdk:${openTelemetryVersion}" implementation "io.opentelemetry:opentelemetry-sdk-metrics:${openTelemetryVersion}" implementation "io.opentelemetry:opentelemetry-exporter-prometheus:${openTelemetryPrometheusVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-xds:${grpcVersion}" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" } diff --git a/examples/example-gcp-csm-observability/settings.gradle b/examples/example-gcp-csm-observability/settings.gradle index 6b7615117d6..44e6f340ede 100644 --- a/examples/example-gcp-csm-observability/settings.gradle +++ b/examples/example-gcp-csm-observability/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'example-gcp-csm-observability' diff --git a/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityClient.java b/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityClient.java index 7387c18da96..dd0ab7eb546 100644 --- a/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityClient.java +++ b/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityClient.java @@ -25,6 +25,7 @@ import io.grpc.examples.helloworld.HelloReply; import io.grpc.examples.helloworld.HelloRequest; import io.grpc.gcp.csm.observability.CsmObservability; +import io.grpc.xds.XdsChannelCredentials; import io.opentelemetry.exporter.prometheus.PrometheusHttpServer; import io.opentelemetry.sdk.OpenTelemetrySdk; import io.opentelemetry.sdk.metrics.SdkMeterProvider; @@ -127,8 +128,10 @@ public void run() { observability.registerGlobal(); // Create a communication channel to the server, known as a Channel. 
- ManagedChannel channel = Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) - .build(); + ManagedChannel channel = + Grpc.newChannelBuilder( + target, XdsChannelCredentials.create(InsecureChannelCredentials.create())) + .build(); CsmObservabilityClient client = new CsmObservabilityClient(channel); try { diff --git a/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityServer.java b/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityServer.java index 78df71b65a9..589753b1a4c 100644 --- a/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityServer.java +++ b/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityServer.java @@ -24,6 +24,8 @@ import io.grpc.examples.helloworld.HelloRequest; import io.grpc.gcp.csm.observability.CsmObservability; import io.grpc.stub.StreamObserver; +import io.grpc.xds.XdsServerBuilder; +import io.grpc.xds.XdsServerCredentials; import io.opentelemetry.exporter.prometheus.PrometheusHttpServer; import io.opentelemetry.sdk.OpenTelemetrySdk; import io.opentelemetry.sdk.metrics.SdkMeterProvider; @@ -40,10 +42,12 @@ public class CsmObservabilityServer { private Server server; private void start(int port) throws IOException { - server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) - .addService(new GreeterImpl()) - .build() - .start(); + server = + XdsServerBuilder.forPort( + port, XdsServerCredentials.create(InsecureServerCredentials.create())) + .addService(new GreeterImpl()) + .build() + .start(); logger.info("Server started, listening on " + port); } diff --git a/examples/example-gcp-observability/build.gradle b/examples/example-gcp-observability/build.gradle index dcb8d420020..a41e7cdd629 100644 --- a/examples/example-gcp-observability/build.gradle +++ 
b/examples/example-gcp-observability/build.gradle @@ -1,16 +1,13 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' id 'java' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } mavenCentral() mavenLocal() } @@ -25,14 +22,13 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-gcp-observability:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" } diff --git a/examples/example-gcp-observability/settings.gradle b/examples/example-gcp-observability/settings.gradle index 1e4ba3812eb..39efc20a459 100644 --- a/examples/example-gcp-observability/settings.gradle +++ b/examples/example-gcp-observability/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'example-gcp-observability' diff --git a/examples/example-hostname/BUILD.bazel b/examples/example-hostname/BUILD.bazel index 8b76f790983..d5bd3aba94c 100644 --- a/examples/example-hostname/BUILD.bazel +++ b/examples/example-hostname/BUILD.bazel @@ -1,5 +1,8 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@io_grpc_grpc_java//:java_grpc_library.bzl", "java_grpc_library") +load("@rules_java//java:java_binary.bzl", "java_binary") +load("@rules_java//java:java_library.bzl", "java_library") proto_library( name = "helloworld_proto", diff --git a/examples/example-hostname/build.gradle b/examples/example-hostname/build.gradle index df8b0fde121..6117b8c32a1 100644 --- a/examples/example-hostname/build.gradle +++ b/examples/example-hostname/build.gradle @@ -2,13 +2,11 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. id 'java' - id "com.google.protobuf" version "0.9.4" - id 'com.google.cloud.tools.jib' version '3.4.3' // For releasing to Docker Hub + id "com.google.protobuf" version "0.9.5" + id 'com.google.cloud.tools.jib' version '3.4.4' // For releasing to Docker Hub } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -23,14 +21,13 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. 
-def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-services:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" testImplementation 'junit:junit:4.13.2' diff --git a/examples/example-hostname/pom.xml b/examples/example-hostname/pom.xml index c6d39887bac..ed90d481587 100644 --- a/examples/example-hostname/pom.xml +++ b/examples/example-hostname/pom.xml @@ -6,14 +6,14 @@ jar - 1.68.0-SNAPSHOT + 1.81.0-SNAPSHOT example-hostname https://github.com/grpc/grpc-java UTF-8 - 1.68.0-SNAPSHOT - 3.25.3 + 1.81.0-SNAPSHOT + 3.25.8 1.8 1.8 @@ -44,12 +44,6 @@ io.grpc grpc-stub - - org.apache.tomcat - annotations-api - 6.0.53 - provided - io.grpc grpc-netty-shaded diff --git a/examples/example-hostname/settings.gradle b/examples/example-hostname/settings.gradle index aa159eb0946..5bd641b3fc1 100644 --- a/examples/example-hostname/settings.gradle +++ b/examples/example-hostname/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'hostname' diff --git a/examples/example-hostname/src/main/java/io/grpc/examples/hostname/HostnameServer.java b/examples/example-hostname/src/main/java/io/grpc/examples/hostname/HostnameServer.java index 3c63296d7fa..7baa2d4733d 100644 --- a/examples/example-hostname/src/main/java/io/grpc/examples/hostname/HostnameServer.java +++ b/examples/example-hostname/src/main/java/io/grpc/examples/hostname/HostnameServer.java @@ -21,7 +21,7 @@ import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; -import io.grpc.protobuf.services.ProtoReflectionService; +import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.services.HealthStatusManager; import java.io.IOException; import java.util.concurrent.TimeUnit; @@ -53,7 +53,7 @@ public static void main(String[] args) throws IOException, InterruptedException HealthStatusManager health = new HealthStatusManager(); final Server server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) .addService(new HostnameGreeter(hostname)) - .addService(ProtoReflectionService.newInstance()) + .addService(ProtoReflectionServiceV1.newInstance()) .addService(health.getHealthService()) .build() .start(); @@ -64,17 +64,17 @@ public void run() { // Start graceful shutdown server.shutdown(); try { - // Wait for RPCs to complete processing - if (!server.awaitTermination(30, TimeUnit.SECONDS)) { - // That was plenty of time. Let's cancel the remaining RPCs - server.shutdownNow(); - // shutdownNow isn't instantaneous, so give a bit of time to clean resources up - // gracefully. Normally this will be well under a second. - server.awaitTermination(5, TimeUnit.SECONDS); - } + // Wait up to 30 seconds for RPCs to complete processing. 
+ server.awaitTermination(30, TimeUnit.SECONDS); } catch (InterruptedException ex) { - server.shutdownNow(); + Thread.currentThread().interrupt(); } + // Cancel any remaining RPCs. If awaitTermination() returned true above, then there are no + // RPCs and the server is already terminated. But it is safe to call even when terminated. + server.shutdownNow(); + // shutdownNow isn't instantaneous, so you want an additional awaitTermination() to give + // time to clean resources up gracefully. Normally it will return in well under a second. In + // this example, the server.awaitTermination() in main() provides that delay. } }); // This would normally be tied to the service's dependencies. For example, if HostnameGreeter diff --git a/examples/example-jwt-auth/build.gradle b/examples/example-jwt-auth/build.gradle index f996282bbb0..5614a72742c 100644 --- a/examples/example-jwt-auth/build.gradle +++ b/examples/example-jwt-auth/build.gradle @@ -1,15 +1,13 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } + mavenCentral() mavenLocal() } @@ -23,8 +21,8 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. 
-def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' def protocVersion = protobufVersion dependencies { @@ -33,8 +31,6 @@ dependencies { implementation "io.jsonwebtoken:jjwt:0.9.1" implementation "javax.xml.bind:jaxb-api:2.3.1" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" - runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" testImplementation "io.grpc:grpc-testing:${grpcVersion}" @@ -53,16 +49,6 @@ protobuf { } } -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} - startScripts.enabled = false task hellowWorldJwtAuthServer(type: CreateStartScripts) { diff --git a/examples/example-jwt-auth/pom.xml b/examples/example-jwt-auth/pom.xml index c84f9893980..7befaf500c5 100644 --- a/examples/example-jwt-auth/pom.xml +++ b/examples/example-jwt-auth/pom.xml @@ -7,15 +7,15 @@ jar - 1.68.0-SNAPSHOT + 1.81.0-SNAPSHOT example-jwt-auth https://github.com/grpc/grpc-java UTF-8 - 1.68.0-SNAPSHOT - 3.25.3 - 3.25.3 + 1.81.0-SNAPSHOT + 3.25.8 + 3.25.8 1.8 1.8 @@ -57,12 +57,6 @@ jaxb-api 2.3.1 - - org.apache.tomcat - annotations-api - 6.0.53 - provided - io.grpc grpc-testing diff --git a/examples/example-jwt-auth/settings.gradle b/examples/example-jwt-auth/settings.gradle index 273558dd9cf..6bd0f0cdc2d 100644 --- a/examples/example-jwt-auth/settings.gradle +++ b/examples/example-jwt-auth/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-oauth/build.gradle b/examples/example-oauth/build.gradle index 7f600c2bc53..07e51217622 100644 --- a/examples/example-oauth/build.gradle +++ b/examples/example-oauth/build.gradle @@ -1,15 +1,13 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } + mavenCentral() mavenLocal() } @@ -23,17 +21,15 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' def protocVersion = protobufVersion dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-auth:${grpcVersion}" - implementation "com.google.auth:google-auth-library-oauth2-http:1.23.0" - - compileOnly "org.apache.tomcat:annotations-api:6.0.53" + implementation "com.google.auth:google-auth-library-oauth2-http:1.42.1" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" @@ -53,16 +49,6 @@ protobuf { } } -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. 
-sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} - startScripts.enabled = false task hellowWorldOauthServer(type: CreateStartScripts) { diff --git a/examples/example-oauth/pom.xml b/examples/example-oauth/pom.xml index fa2eaa41e36..9ce20f2f684 100644 --- a/examples/example-oauth/pom.xml +++ b/examples/example-oauth/pom.xml @@ -7,15 +7,15 @@ jar - 1.68.0-SNAPSHOT + 1.81.0-SNAPSHOT example-oauth https://github.com/grpc/grpc-java UTF-8 - 1.68.0-SNAPSHOT - 3.25.3 - 3.25.3 + 1.81.0-SNAPSHOT + 3.25.8 + 3.25.8 1.8 1.8 @@ -30,6 +30,11 @@ pom import + + com.google.code.gson + gson + 2.13.2 + @@ -50,23 +55,11 @@ io.grpc grpc-auth - - - com.google.auth - google-auth-library-credentials - - com.google.auth google-auth-library-oauth2-http - 1.23.0 - - - org.apache.tomcat - annotations-api - 6.0.53 - provided + 1.40.0 io.grpc diff --git a/examples/example-oauth/settings.gradle b/examples/example-oauth/settings.gradle index 273558dd9cf..6bd0f0cdc2d 100644 --- a/examples/example-oauth/settings.gradle +++ b/examples/example-oauth/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-opentelemetry/build.gradle b/examples/example-opentelemetry/build.gradle index 21264ffcc17..a24900c0fe5 100644 --- a/examples/example-opentelemetry/build.gradle +++ b/examples/example-opentelemetry/build.gradle @@ -1,15 +1,12 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } mavenCentral() mavenLocal() } @@ -24,10 +21,10 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. 
-def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.3' -def openTelemetryVersion = '1.40.0' -def openTelemetryPrometheusVersion = '1.40.0-alpha' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' +def openTelemetryVersion = '1.56.0' +def openTelemetryPrometheusVersion = '1.56.0-alpha' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" @@ -37,7 +34,6 @@ dependencies { implementation "io.opentelemetry:opentelemetry-sdk-metrics:${openTelemetryVersion}" implementation "io.opentelemetry:opentelemetry-exporter-logging:${openTelemetryVersion}" implementation "io.opentelemetry:opentelemetry-exporter-prometheus:${openTelemetryPrometheusVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" } diff --git a/examples/example-opentelemetry/settings.gradle b/examples/example-opentelemetry/settings.gradle index ff7ea3fc2be..26e3bea044b 100644 --- a/examples/example-opentelemetry/settings.gradle +++ b/examples/example-opentelemetry/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'example-opentelemetry' diff --git a/examples/example-orca/build.gradle b/examples/example-orca/build.gradle index d087a532aff..674c4bdf2f7 100644 --- a/examples/example-orca/build.gradle +++ b/examples/example-orca/build.gradle @@ -1,14 +1,12 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. 
- id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' id 'java' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -18,16 +16,14 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-services:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-xds:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" - } protobuf { diff --git a/examples/example-orca/settings.gradle b/examples/example-orca/settings.gradle index 3c62dc663ce..12536c0ca8d 100644 --- a/examples/example-orca/settings.gradle +++ b/examples/example-orca/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'example-orca' diff --git a/examples/example-reflection/README.md b/examples/example-reflection/README.md index 801a27343db..4bc30e84b3b 100644 --- a/examples/example-reflection/README.md +++ b/examples/example-reflection/README.md @@ -1,7 +1,7 @@ gRPC Reflection Example ================ -The reflection example has a Hello World server with `ProtoReflectionService` registered. +The reflection example has a Hello World server with `ProtoReflectionServiceV1` registered. 
### Build the example diff --git a/examples/example-reflection/build.gradle b/examples/example-reflection/build.gradle index d7d5c50b7e6..aa870967135 100644 --- a/examples/example-reflection/build.gradle +++ b/examples/example-reflection/build.gradle @@ -1,14 +1,12 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' id 'java' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -18,16 +16,14 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-services:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-netty-shaded:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" - } protobuf { diff --git a/examples/example-reflection/settings.gradle b/examples/example-reflection/settings.gradle index dccb973085e..28e44b77905 100644 --- a/examples/example-reflection/settings.gradle +++ b/examples/example-reflection/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'example-reflection' diff --git a/examples/example-reflection/src/main/java/io/grpc/examples/reflection/ReflectionServer.java b/examples/example-reflection/src/main/java/io/grpc/examples/reflection/ReflectionServer.java index ad702247ba7..8406317aad6 100644 --- a/examples/example-reflection/src/main/java/io/grpc/examples/reflection/ReflectionServer.java +++ b/examples/example-reflection/src/main/java/io/grpc/examples/reflection/ReflectionServer.java @@ -7,7 +7,7 @@ import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; import io.grpc.examples.helloworld.HelloRequest; -import io.grpc.protobuf.services.ProtoReflectionService; +import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.stub.StreamObserver; import java.io.IOException; import java.util.concurrent.TimeUnit; @@ -26,7 +26,7 @@ private void start() throws IOException { int port = 50051; server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) .addService(new GreeterImpl()) - .addService(ProtoReflectionService.newInstance()) // add reflection service + .addService(ProtoReflectionServiceV1.newInstance()) // add reflection service .build() .start(); logger.info("Server started, listening on " + port); diff --git a/examples/example-servlet/build.gradle b/examples/example-servlet/build.gradle index 995e2d0979b..7f23c83e0d9 100644 --- a/examples/example-servlet/build.gradle +++ b/examples/example-servlet/build.gradle @@ -1,13 +1,12 @@ plugins { - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' id 'war' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url 
"https://maven-central.storage-download.googleapis.com/maven2/" } + mavenCentral() mavenLocal() } @@ -16,16 +15,15 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}", "io.grpc:grpc-servlet:${grpcVersion}", "io.grpc:grpc-stub:${grpcVersion}" - compileOnly "javax.servlet:javax.servlet-api:4.0.1", - "org.apache.tomcat:annotations-api:6.0.53" + compileOnly "javax.servlet:javax.servlet-api:4.0.1" } protobuf { @@ -35,13 +33,3 @@ protobuf { all()*.plugins { grpc {} } } } - -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} diff --git a/examples/example-servlet/settings.gradle b/examples/example-servlet/settings.gradle index 273558dd9cf..6bd0f0cdc2d 100644 --- a/examples/example-servlet/settings.gradle +++ b/examples/example-servlet/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-tls/BUILD.bazel b/examples/example-tls/BUILD.bazel index 81913836766..cb46ef5bb30 100644 --- a/examples/example-tls/BUILD.bazel +++ b/examples/example-tls/BUILD.bazel @@ -1,5 +1,8 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@io_grpc_grpc_java//:java_grpc_library.bzl", "java_grpc_library") +load("@rules_java//java:java_binary.bzl", "java_binary") +load("@rules_java//java:java_library.bzl", "java_library") proto_library( name = "helloworld_proto", diff --git a/examples/example-tls/build.gradle b/examples/example-tls/build.gradle index 8aad6b62bcb..456cb8b4f73 100644 --- a/examples/example-tls/build.gradle +++ b/examples/example-tls/build.gradle @@ -1,15 +1,12 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } mavenCentral() mavenLocal() } @@ -24,13 +21,12 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. 
-def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" } @@ -44,16 +40,6 @@ protobuf { } } -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} - startScripts.enabled = false task helloWorldTlsServer(type: CreateStartScripts) { diff --git a/examples/example-tls/pom.xml b/examples/example-tls/pom.xml index e1d569a628c..ff9d01253f5 100644 --- a/examples/example-tls/pom.xml +++ b/examples/example-tls/pom.xml @@ -6,14 +6,14 @@ jar - 1.68.0-SNAPSHOT + 1.81.0-SNAPSHOT example-tls https://github.com/grpc/grpc-java UTF-8 - 1.68.0-SNAPSHOT - 3.25.3 + 1.81.0-SNAPSHOT + 3.25.8 1.8 1.8 @@ -40,12 +40,6 @@ io.grpc grpc-stub - - org.apache.tomcat - annotations-api - 6.0.53 - provided - io.grpc grpc-netty-shaded diff --git a/examples/example-tls/settings.gradle b/examples/example-tls/settings.gradle index 273558dd9cf..6bd0f0cdc2d 100644 --- a/examples/example-tls/settings.gradle +++ b/examples/example-tls/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-xds/build.gradle b/examples/example-xds/build.gradle index 8339db77e0c..e8b3f3dd395 100644 --- a/examples/example-xds/build.gradle +++ b/examples/example-xds/build.gradle @@ -1,14 +1,12 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' id 'java' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -23,15 +21,14 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.3' +def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-services:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-xds:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" } diff --git a/examples/example-xds/settings.gradle b/examples/example-xds/settings.gradle index 878f1f23ae3..4197fa6760d 100644 --- a/examples/example-xds/settings.gradle +++ b/examples/example-xds/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'example-xds' diff --git a/examples/example-xds/src/main/java/io/grpc/examples/helloworldxds/XdsHelloWorldServer.java b/examples/example-xds/src/main/java/io/grpc/examples/helloworldxds/XdsHelloWorldServer.java index 93317dda23e..c7c67f8d681 100644 --- a/examples/example-xds/src/main/java/io/grpc/examples/helloworldxds/XdsHelloWorldServer.java +++ b/examples/example-xds/src/main/java/io/grpc/examples/helloworldxds/XdsHelloWorldServer.java @@ -20,7 +20,7 @@ import io.grpc.Server; import io.grpc.ServerCredentials; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; -import io.grpc.protobuf.services.ProtoReflectionService; +import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.services.HealthStatusManager; import io.grpc.xds.XdsServerBuilder; import io.grpc.xds.XdsServerCredentials; @@ -66,7 +66,7 @@ public static void main(String[] args) throws IOException, InterruptedException final HealthStatusManager health = new HealthStatusManager(); final Server server = XdsServerBuilder.forPort(port, credentials) .addService(new HostnameGreeter(hostname)) - .addService(ProtoReflectionService.newInstance()) // convenient for command line tools + .addService(ProtoReflectionServiceV1.newInstance()) // convenient for command line tools .addService(health.getHealthService()) // allow management servers to monitor health .build() .start(); diff --git a/examples/gradle/wrapper/gradle-wrapper.properties b/examples/gradle/wrapper/gradle-wrapper.properties index 0d1842103b1..1e2fbf0d458 100644 --- a/examples/gradle/wrapper/gradle-wrapper.properties +++ b/examples/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists 
-distributionUrl=https\://services.gradle.org/distributions/gradle-8.8-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.10.2-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/examples/maven-assembly-jar-with-dependencies.xml b/examples/maven-assembly-jar-with-dependencies.xml new file mode 100644 index 00000000000..6c8abbfe7e8 --- /dev/null +++ b/examples/maven-assembly-jar-with-dependencies.xml @@ -0,0 +1,27 @@ + + + jar-with-dependencies + + jar + + false + + + / + true + true + runtime + + + + + metaInf-services + + + diff --git a/examples/pom.xml b/examples/pom.xml index 247df4a73ce..5375b930b3b 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -6,15 +6,15 @@ jar - 1.68.0-SNAPSHOT + 1.81.0-SNAPSHOT examples https://github.com/grpc/grpc-java UTF-8 - 1.68.0-SNAPSHOT - 3.25.3 - 3.25.3 + 1.81.0-SNAPSHOT + 3.25.8 + 3.25.8 1.8 1.8 @@ -58,13 +58,7 @@ com.google.j2objc j2objc-annotations - 3.0.0 - - - org.apache.tomcat - annotations-api - 6.0.53 - provided + 3.1 io.grpc @@ -130,6 +124,35 @@ + + + + + + + + maven-assembly-plugin + 3.7.1 + + ${project.basedir}/maven-assembly-jar-with-dependencies.xml + + + + make-assembly + package + + single + + + + diff --git a/examples/settings.gradle b/examples/settings.gradle index 0473750a54f..4d39e8b45ba 100644 --- a/examples/settings.gradle +++ b/examples/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/src/main/java/io/grpc/examples/advanced/README.md b/examples/src/main/java/io/grpc/examples/advanced/README.md new file mode 100644 index 00000000000..f5b5c6cc7fc --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/advanced/README.md @@ -0,0 +1,16 @@ +gRPC JSON Serialization Example +===================== + +gRPC is a modern high-performance framework for building Remote Procedure Call (RPC) systems. +It commonly uses Protocol Buffers (Protobuf) as its serialization format, which is compact and efficient. +However, gRPC can also support JSON serialization when needed, typically for interoperability with +systems or clients that do not use Protobuf. +This is an advanced example of how to swap out the serialization logic, Normal users do not need to do this. +This code is not intended to be a production-ready implementation, since JSON encoding is slow. +Additionally, JSON serialization as implemented may be not resilient to malicious input. 
+ +This advanced example uses Marshaller for JSON which marshals in the Protobuf 3 format described here +https://developers.google.com/protocol-buffers/docs/proto3#json + +If you are considering implementing your own serialization logic, contact the grpc team at +https://groups.google.com/forum/#!forum/grpc-io diff --git a/examples/src/main/java/io/grpc/examples/cancellation/README.md b/examples/src/main/java/io/grpc/examples/cancellation/README.md new file mode 100644 index 00000000000..6b11a17c517 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/cancellation/README.md @@ -0,0 +1,18 @@ +gRPC Cancellation Example +===================== + +When a gRPC client is no longer interested in the result of an RPC call, +it may cancel to signal this discontinuation of interest to the server. + +Any abort of an ongoing RPC is considered "cancellation" of that RPC. +The common causes of cancellation are the client explicitly cancelling, the deadline expires, and I/O failures. +The service is not informed the reason for the cancellation. + +There are two APIs for services to be notified of RPC cancellation: io.grpc.Context and ServerCallStreamObserver + +Context listeners are called on a different thread, so need to be thread-safe. +The ServerCallStreamObserver cancellation callback is called like other StreamObserver callbacks, +so the application may not need thread-safe handling. +Both APIs have thread-safe isCancelled() polling methods. 
+ +Refer the gRPC documentation for details on Cancellation of RPCs https://grpc.io/docs/guides/cancellation/ diff --git a/examples/src/main/java/io/grpc/examples/customloadbalance/README.md b/examples/src/main/java/io/grpc/examples/customloadbalance/README.md new file mode 100644 index 00000000000..20dbccb81ac --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/customloadbalance/README.md @@ -0,0 +1,19 @@ +gRPC Custom Load Balance Example +===================== + +One of the key features of gRPC is load balancing, which allows requests from clients to be distributed across multiple servers. +This helps prevent any one server from becoming overloaded and allows the system to scale up by adding more servers. + +A gRPC load balancing policy is given a list of server IP addresses by the name resolver. +The policy is responsible for maintaining connections (subchannels) to the servers and picking a connection to use when an RPC is sent. + +This example gives the details about how we can implement our own custom load balance policy, If the built-in policies does not meet your requirements +and follow below steps for the same. + + - Register your implementation in the load balancer registry so that it can be referred to from the service config + - Parse the JSON configuration object of your implementation. This allows your load balancer to be configured in the service config with any arbitrary JSON you choose to support + - Manage what backends to maintain a connection with + - Implement a picker that will choose which backend to connect to when an RPC is made. 
Note that this needs to be a fast operation as it is on the RPC call path + - To enable your load balancer, configure it in your service config + +Refer the gRPC documentation for more details https://grpc.io/docs/guides/custom-load-balancing/ diff --git a/examples/src/main/java/io/grpc/examples/customloadbalance/ShufflingPickFirstLoadBalancer.java b/examples/src/main/java/io/grpc/examples/customloadbalance/ShufflingPickFirstLoadBalancer.java index 4cf09170c8d..4715b551524 100644 --- a/examples/src/main/java/io/grpc/examples/customloadbalance/ShufflingPickFirstLoadBalancer.java +++ b/examples/src/main/java/io/grpc/examples/customloadbalance/ShufflingPickFirstLoadBalancer.java @@ -92,7 +92,7 @@ public void onSubchannelState(ConnectivityStateInfo stateInfo) { }); this.subchannel = subchannel; - helper.updateBalancingState(CONNECTING, new Picker(PickResult.withNoResult())); + helper.updateBalancingState(CONNECTING, new FixedResultPicker(PickResult.withNoResult())); subchannel.requestConnection(); } else { subchannel.updateAddresses(servers); @@ -107,7 +107,8 @@ public void handleNameResolutionError(Status error) { subchannel.shutdown(); subchannel = null; } - helper.updateBalancingState(TRANSIENT_FAILURE, new Picker(PickResult.withError(error))); + helper.updateBalancingState( + TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); } private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { @@ -122,16 +123,16 @@ private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo SubchannelPicker picker; switch (currentState) { case IDLE: - picker = new RequestConnectionPicker(subchannel); + picker = new RequestConnectionPicker(); break; case CONNECTING: - picker = new Picker(PickResult.withNoResult()); + picker = new FixedResultPicker(PickResult.withNoResult()); break; case READY: - picker = new Picker(PickResult.withSubchannel(subchannel)); + picker = new 
FixedResultPicker(PickResult.withSubchannel(subchannel)); break; case TRANSIENT_FAILURE: - picker = new Picker(PickResult.withError(stateInfo.getStatus())); + picker = new FixedResultPicker(PickResult.withError(stateInfo.getStatus())); break; default: throw new IllegalArgumentException("Unsupported state:" + currentState); @@ -154,52 +155,20 @@ public void requestConnection() { } } - /** - * No-op picker which doesn't add any custom picking logic. It just passes already known result - * received in constructor. - */ - private static final class Picker extends SubchannelPicker { - - private final PickResult result; - - Picker(PickResult result) { - this.result = checkNotNull(result, "result"); - } - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return result; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(Picker.class).add("result", result).toString(); - } - } - /** * Picker that requests connection during the first pick, and returns noResult. 
*/ private final class RequestConnectionPicker extends SubchannelPicker { - private final Subchannel subchannel; private final AtomicBoolean connectionRequested = new AtomicBoolean(false); - RequestConnectionPicker(Subchannel subchannel) { - this.subchannel = checkNotNull(subchannel, "subchannel"); - } - @Override public PickResult pickSubchannel(PickSubchannelArgs args) { if (connectionRequested.compareAndSet(false, true)) { - helper.getSynchronizationContext().execute(new Runnable() { - @Override - public void run() { - subchannel.requestConnection(); - } - }); + helper.getSynchronizationContext().execute( + ShufflingPickFirstLoadBalancer.this::requestConnection); } return PickResult.withNoResult(); } } -} \ No newline at end of file +} diff --git a/examples/src/main/java/io/grpc/examples/deadline/README.md b/examples/src/main/java/io/grpc/examples/deadline/README.md new file mode 100644 index 00000000000..3c7646f1e5f --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/deadline/README.md @@ -0,0 +1,15 @@ +gRPC Deadline Example +===================== + +A Deadline is used to specify a point in time past which a client is unwilling to wait for a response from a server. +This simple idea is very important in building robust distributed systems. +Clients that do not wait around unnecessarily and servers that know when to give up processing requests will improve the resource utilization and latency of your system. + +Note that while some language APIs have the concept of a deadline, others use the idea of a timeout. +When an API asks for a deadline, you provide a point in time which the call should not go past. +A timeout is the max duration of time that the call can take. +A timeout can be converted to a deadline by adding the timeout to the current time when the application starts a call. + +This Example gives usage and implementation of Deadline on Server, Client and Propagation. 
+ +Refer the gRPC documentation for more details on Deadlines https://grpc.io/docs/guides/deadlines/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/errordetails/README.md b/examples/src/main/java/io/grpc/examples/errordetails/README.md new file mode 100644 index 00000000000..8f241ba37a7 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/errordetails/README.md @@ -0,0 +1,16 @@ +gRPC Error Details Example +===================== + +If a gRPC call completes successfully the server returns an OK status to the client (depending on the language the OK status may or may not be directly used in your code). +But what happens if the call isn’t successful? + +This Example gives the usage and implementation of how return the error details if gRPC call not successful or fails +and how to set and read com.google.rpc.Status objects as google.rpc.Status error details. + +gRPC allows detailed error information to be encapsulated in protobuf messages, which are sent alongside the status codes. + +If an error occurs, gRPC returns one of its error status codes with error message that provides further error details about what happened. + +Refer the below links for more details on error details and status codes +- https://grpc.io/docs/guides/error/ +- https://github.com/grpc/grpc-java/blob/master/api/src/main/java/io/grpc/Status.java \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/errorhandling/README.md b/examples/src/main/java/io/grpc/examples/errorhandling/README.md new file mode 100644 index 00000000000..a920e939c86 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/errorhandling/README.md @@ -0,0 +1,27 @@ +gRPC Error Handling Example +===================== + +Error handling in gRPC is a critical aspect of designing reliable and robust distributed systems. +gRPC provides a standardized mechanism for handling errors using status codes, error details, and optional metadata. 
+ +This Example gives the usage and implementation of how to handle the Errors/Exceptions in gRPC, +shows how to extract error information from a failed RPC and setting and reading RPC error details. + +If a gRPC call completes successfully the server returns an OK status to the client (depending on the language the OK status may or may not be directly used in your code). + +If an error occurs gRPC returns one of its error status codes with error message that provides further error details about what happened. + +Error Propagation: +- When an error occurs on the server, gRPC stops processing the RPC and sends the error (status code, description, and optional details) to the client. +- On the client side, the error can be handled based on the status code. + +Client Side Error Handling: + - The gRPC client typically throws an exception or returns an error object when an RPC fails. + +Server Side Error Handling: +- Servers use the gRPC API to return errors explicitly using the grpc library's status functions. + +gRPC uses predefined status codes to represent the outcome of an RPC call. These status codes are part of the Status object that is sent from the server to the client. 
+Each status code is accompanied by a human-readable description(Please refer https://github.com/grpc/grpc-java/blob/master/api/src/main/java/io/grpc/Status.java) + +Refer the gRPC documentation for more details on Error Handling https://grpc.io/docs/guides/error/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/experimental/README.md b/examples/src/main/java/io/grpc/examples/experimental/README.md new file mode 100644 index 00000000000..295b0801538 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/experimental/README.md @@ -0,0 +1,13 @@ +gRPC Compression Example +===================== + +This example shows how clients can specify compression options when performing RPCs, +and how to enable compressed(i,e gzip) requests/responses for only particular method and in case of all methods by using the interceptors. + +Compression is used to reduce the amount of bandwidth used when communicating between client/server or peers and +can be enabled or disabled based on call or message level for all languages. + +gRPC allows asymmetrically compressed communication, whereby a response may be compressed differently with the request, +or not compressed at all. + +Refer the gRPC documentation for more details on Compression https://grpc.io/docs/guides/compression/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/grpcproxy/README.md b/examples/src/main/java/io/grpc/examples/grpcproxy/README.md new file mode 100644 index 00000000000..cc13dc3d9d0 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/grpcproxy/README.md @@ -0,0 +1,22 @@ +gRPC Proxy Example +===================== + +A gRPC proxy is a component or tool that acts as an intermediary between gRPC clients and servers, +facilitating communication while offering additional capabilities. +Proxies are used in scenarios where you need to handle tasks like load balancing, routing, monitoring, +or providing a bridge between gRPC and other protocols. 
+ +GrpcProxy itself can be used unmodified to proxy any service for both unary and streaming. +It doesn't care what type of messages are being used. +The Registry class causes it to be called for any inbound RPC, and uses plain bytes for messages which avoids marshalling +messages and the need for Protobuf schema information. + +We can run the Grpc Proxy with Route guide example to see how it works by running the below + +Route guide has unary and streaming RPCs which makes it a nice showcase, and we can run each in a separate terminal window. + +./build/install/examples/bin/route-guide-server +./build/install/examples/bin/grpc-proxy +./build/install/examples/bin/route-guide-client localhost:8981 + +you can verify the proxy is being used by shutting down the proxy and seeing the client fail. \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/header/HeaderClientInterceptor.java b/examples/src/main/java/io/grpc/examples/header/HeaderClientInterceptor.java index b9a73931299..2a60eeda6c4 100644 --- a/examples/src/main/java/io/grpc/examples/header/HeaderClientInterceptor.java +++ b/examples/src/main/java/io/grpc/examples/header/HeaderClientInterceptor.java @@ -52,7 +52,7 @@ public void start(Listener responseListener, Metadata headers) { public void onHeaders(Metadata headers) { /** * if you don't need receive header from server, - * you can use {@link io.grpc.stub.MetadataUtils#attachHeaders} + * you can use {@link io.grpc.stub.MetadataUtils#newAttachHeadersInterceptor} * directly to send header */ logger.info("header received from server:" + headers); diff --git a/examples/src/main/java/io/grpc/examples/header/README.md b/examples/src/main/java/io/grpc/examples/header/README.md new file mode 100644 index 00000000000..1563a2799cc --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/header/README.md @@ -0,0 +1,16 @@ +gRPC Custom Header Example +===================== + +This example gives the usage and implementation of how to create 
and process (send/receive) the custom headers between Client and Server +using the interceptors (HeaderServerInterceptor, HeaderClientInterceptor) along with Metadata. + +Metadata is a side channel that allows clients and servers to provide information to each other that is associated with an RPC. +gRPC metadata is a key-value pair of data that is sent with initial or final gRPC requests or responses. +It is used to provide additional information about the call, such as authentication credentials, +tracing information, or custom headers. + +gRPC metadata can be used to send custom headers to the server or from the server to the client. +This can be used to implement application-specific features, such as load balancing, +rate limiting, or providing detailed error messages from the server to the client. + +Refer to the gRPC documentation for more on Metadata/Headers https://grpc.io/docs/guides/metadata/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/healthservice/README.md b/examples/src/main/java/io/grpc/examples/healthservice/README.md new file mode 100644 index 00000000000..181bd70977f --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/healthservice/README.md @@ -0,0 +1,10 @@ +gRPC Health Service Example +===================== + +The Health Service example provides a HelloWorld gRPC server that doesn't like short names along with a +health service. It also provides a client application which makes HelloWorld +calls and checks the health status. + +The client application also shows how the round robin load balancer can +utilize the health status to avoid making calls to a service that is +not actively serving. 
diff --git a/examples/src/main/java/io/grpc/examples/hedging/README.md b/examples/src/main/java/io/grpc/examples/hedging/README.md new file mode 100644 index 00000000000..0154e5c2cee --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/hedging/README.md @@ -0,0 +1,59 @@ +gRPC Hedging Example +===================== + +The Hedging example demonstrates that enabling hedging +can reduce tail latency. (Users should note that enabling hedging may introduce other overhead; +and in some scenarios, such as when some server resource gets exhausted for a period of time and +almost every RPC during that time has high latency or fails, hedging may make things worse. +Setting a throttle in the service config is recommended to protect the server from too many +inappropriate retry or hedging requests.) + +The server and the client in the example are basically the same as those in the +[hello world](src/main/java/io/grpc/examples/helloworld) example, except that the server mimics a +long tail of latency, and the client sends 2000 requests and can turn on and off hedging. + +To mimic the latency, the server randomly delays the RPC handling by 2 seconds at 10% chance, 5 +seconds at 5% chance, and 10 seconds at 1% chance. + +When running the client enabling the following hedging policy + + ```json + "hedgingPolicy": { + "maxAttempts": 3, + "hedgingDelay": "1s" + } + ``` +Then the latency summary in the client log is like the following + + ```text + Total RPCs sent: 2,000. Total RPCs failed: 0 + [Hedging enabled] + ======================== + 50% latency: 0ms + 90% latency: 6ms + 95% latency: 1,003ms + 99% latency: 2,002ms + 99.9% latency: 2,011ms + Max latency: 5,272ms + ======================== + ``` + +See [the section below](#to-build-the-examples) for how to build and run the example. The +executables for the server and the client are `hedging-hello-world-server` and +`hedging-hello-world-client`. 
+ +To disable hedging, set environment variable `DISABLE_HEDGING_IN_HEDGING_EXAMPLE=true` before +running the client. That produces a latency summary in the client log like the following + + ```text + Total RPCs sent: 2,000. Total RPCs failed: 0 + [Hedging disabled] + ======================== + 50% latency: 0ms + 90% latency: 2,002ms + 95% latency: 5,002ms + 99% latency: 10,004ms + 99.9% latency: 10,007ms + Max latency: 10,007ms + ======================== + ``` diff --git a/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java b/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java index 81027587031..0e39581c98f 100644 --- a/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java +++ b/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java @@ -23,6 +23,8 @@ import java.io.IOException; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * Server that manages startup/shutdown of a {@code Greeter} server. @@ -31,11 +33,20 @@ public class HelloWorldServer { private static final Logger logger = Logger.getLogger(HelloWorldServer.class.getName()); private Server server; - private void start() throws IOException { /* The port on which the server should run */ int port = 50051; + /* + * By default gRPC uses a global, shared Executor.newCachedThreadPool() for gRPC callbacks into + * your application. This is convenient, but can cause an excessive number of threads to be + * created if there are many RPCs. It is often better to limit the number of threads your + * application uses for processing and let RPCs queue when the CPU is saturated. + * The appropriate number of threads varies heavily between applications. + * Async application code generally does not need more threads than CPU cores. 
+ */ + ExecutorService executor = Executors.newFixedThreadPool(2); server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) + .executor(executor) .addService(new GreeterImpl()) .build() .start(); @@ -48,7 +59,12 @@ public void run() { try { HelloWorldServer.this.stop(); } catch (InterruptedException e) { + if (server != null) { + server.shutdownNow(); + } e.printStackTrace(System.err); + } finally { + executor.shutdown(); } System.err.println("*** server shut down"); } diff --git a/examples/src/main/java/io/grpc/examples/helloworld/README.md b/examples/src/main/java/io/grpc/examples/helloworld/README.md new file mode 100644 index 00000000000..5b11d4945c2 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/helloworld/README.md @@ -0,0 +1,7 @@ +gRPC Hello World Example +===================== +This Example gives the details about basic implementation of gRPC Client and Server along with +how the communication happens between them by sending a greeting message. + +Refer the gRPC documentation for more details on helloworld.proto specification, creation of gRPC services and +methods along with Execution process https://grpc.io/docs/languages/java/quickstart/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/keepalive/README.md b/examples/src/main/java/io/grpc/examples/keepalive/README.md new file mode 100644 index 00000000000..7b5b72665e7 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/keepalive/README.md @@ -0,0 +1,16 @@ +gRPC Keepalive Example +===================== + +This example gives the usage and implementation of the Keepalives methods, configurations in gRPC Client and +Server and how the communication happens between them. + +HTTP/2 PING-based keepalives are a way to keep an HTTP/2 connection alive even when there is no data being transferred. +This is done by periodically sending a PING Frames to the other end of the connection. 
+HTTP/2 keepalives can improve performance and reliability of HTTP/2 connections, +but it is important to configure the keepalive interval carefully. + +gRPC sends http2 pings on the transport to detect if the connection is down. +If the ping is not acknowledged by the other side within a certain period, the connection will be closed. +Note that pings are only necessary when there's no activity on the connection. + +Refer the gRPC documentation for more on Keepalive details and configurations https://grpc.io/docs/guides/keepalive/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/loadbalance/README.md b/examples/src/main/java/io/grpc/examples/loadbalance/README.md new file mode 100644 index 00000000000..0d19d2f3335 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/loadbalance/README.md @@ -0,0 +1,20 @@ +gRPC Load Balance Example +===================== + +One of the key features of gRPC is load balancing, which allows requests from clients to be distributed across multiple servers. +This helps prevent any one server from becoming overloaded and allows the system to scale up by adding more servers. + +A gRPC load balancing policy is given a list of server IP addresses by the name resolver. +The policy is responsible for maintaining connections (subchannels) to the servers and picking a connection to use when an RPC is sent. + +By default, the pick_first policy will be used. +This policy actually does no load balancing but just tries each address it gets from the name resolver and uses the first one it can connect to. +By updating the gRPC service config you can also switch to using round_robin that connects to every address it gets and rotates through the connected backends for each RPC. +There are also some other load balancing policies available, but the exact set varies by language. 
+ +This example gives the details about how to implement Load Balance in gRPC, If the built-in policies does not meet your requirements +you can implement your own custom load balance [Custom Load Balance](src/main/java/io/grpc/examples/customloadbalance) + +gRPC supports both client side and server side load balancing but by default gRPC uses client side load balancing. + +Refer the gRPC documentation for more details on Load Balancing https://grpc.io/blog/grpc-load-balancing/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/manualflowcontrol/BidiBlockingClient.java b/examples/src/main/java/io/grpc/examples/manualflowcontrol/BidiBlockingClient.java new file mode 100644 index 00000000000..902d46c8cc6 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/manualflowcontrol/BidiBlockingClient.java @@ -0,0 +1,286 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.examples.manualflowcontrol; + +import com.google.protobuf.ByteString; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.StatusException; +import io.grpc.examples.manualflowcontrol.StreamingGreeterGrpc.StreamingGreeterBlockingV2Stub; +import io.grpc.stub.BlockingClientCall; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + + +/** + * A class that tries multiple ways to do blocking bidi streaming + * communication with an echo server + */ +public class BidiBlockingClient { + + private static final Logger logger = Logger.getLogger(BidiBlockingClient.class.getName()); + + /** + * Greet server. If provided, the first element of {@code args} is the name to use in the + * greeting. The second argument is the target server. You can see the multiplexing in the server + * logs. + */ + public static void main(String[] args) throws Exception { + System.setProperty("java.util.logging.SimpleFormatter.format", "%1$tH:%1$tM:%1$tS %5$s%6$s%n"); + + // Access a service running on the local machine on port 50051 + String target = "localhost:50051"; + // Allow passing in the user and target strings as command line arguments + if (args.length > 0) { + if ("--help".equals(args[0])) { + System.err.println("Usage: [target]\n"); + System.err.println(" target The server to connect to. Defaults to " + target); + System.exit(1); + } + target = args[0]; + } + + // Create a communication channel to the server, known as a Channel. Channels are thread-safe + // and reusable. It is common to create channels at the beginning of your application and reuse + // them until the application shuts down. + // + // For the example we use plaintext insecure credentials to avoid needing TLS certificates. To + // use TLS, use TlsChannelCredentials instead. 
+ ManagedChannel channel = Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) + .build(); + StreamingGreeterBlockingV2Stub blockingStub = StreamingGreeterGrpc.newBlockingV2Stub(channel); + List echoInput = names(); + try { + long start = System.currentTimeMillis(); + List twoThreadResult = useTwoThreads(blockingStub, echoInput); + long finish = System.currentTimeMillis(); + + System.out.println("The echo requests and results were:"); + printResultMessage("Input", echoInput, 0L); + printResultMessage("2 threads", twoThreadResult, finish - start); + } finally { + // ManagedChannels use resources like threads and TCP connections. To prevent leaking these + // resources the channel should be shut down when it will no longer be used. If it may be used + // again leave it running. + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + } + + private static void printResultMessage(String type, List result, long millis) { + String msg = String.format("%-32s: %2d, %.3f sec", type, result.size(), millis/1000.0); + logger.info(msg); + } + + private static void logMethodStart(String method) { + logger.info("--------------------- Starting to process using method: " + method); + } + + /** + * Create 2 threads, one that writes all values, and one that reads until the stream closes. 
+ */ + private static List useTwoThreads(StreamingGreeterBlockingV2Stub blockingStub, + List valuesToWrite) throws InterruptedException { + logMethodStart("Two Threads"); + + List readValues = new ArrayList<>(); + final BlockingClientCall stream = blockingStub.sayHelloStreaming(); + + Thread reader = new Thread(null, + new Runnable() { + @Override + public void run() { + int count = 0; + try { + while (stream.hasNext()) { + readValues.add(stream.read().getMessage()); + if (++count % 10 == 0) { + logger.info("Finished " + count + " reads"); + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + stream.cancel("Interrupted", e); + } catch (StatusException e) { + logger.warning("Encountered error while reading: " + e); + } + } + },"reader"); + + Thread writer = new Thread(null, + new Runnable() { + @Override + public void run() { + ByteString padding = createPadding(); + int count = 0; + Iterator iterator = valuesToWrite.iterator(); + boolean hadProblem = false; + try { + while (iterator.hasNext()) { + if (!stream.write(HelloRequest.newBuilder().setName(iterator.next()).setPadding(padding) + .build())) { + logger.warning("Stream closed before writes completed"); + hadProblem = true; + break; + } + if (++count % 10 == 0) { + logger.info("Finished " + count + " writes"); + } + } + if (!hadProblem) { + logger.info("Completed writes"); + stream.halfClose(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + stream.cancel("Interrupted", e); + } catch (StatusException e) { + logger.warning("Encountered error while writing: " + e); + } + } + }, "writer"); + + writer.start(); + reader.start(); + writer.join(); + reader.join(); + + return readValues; + } + + private static ByteString createPadding() { + int multiple = 50; + ByteBuffer data = ByteBuffer.allocate(1024 * multiple); + + for (int i = 0; i < multiple * 1024 / 4; i++) { + data.putInt(4 * i, 1111); + } + + return ByteString.copyFrom(data); + } + + + private 
static List names() { + return Arrays.asList( + "Sophia", + "Jackson", + "Emma", + "Aiden", + "Olivia", + "Lucas", + "Ava", + "Liam", + "Mia", + "Noah", + "Isabella", + "Ethan", + "Riley", + "Mason", + "Aria", + "Caden", + "Zoe", + "Oliver", + "Charlotte", + "Elijah", + "Lily", + "Grayson", + "Layla", + "Jacob", + "Amelia", + "Michael", + "Emily", + "Benjamin", + "Madelyn", + "Carter", + "Aubrey", + "James", + "Adalyn", + "Jayden", + "Madison", + "Logan", + "Chloe", + "Alexander", + "Harper", + "Caleb", + "Abigail", + "Ryan", + "Aaliyah", + "Luke", + "Avery", + "Daniel", + "Evelyn", + "Jack", + "Kaylee", + "William", + "Ella", + "Owen", + "Ellie", + "Gabriel", + "Scarlett", + "Matthew", + "Arianna", + "Connor", + "Hailey", + "Jayce", + "Nora", + "Isaac", + "Addison", + "Sebastian", + "Brooklyn", + "Henry", + "Hannah", + "Muhammad", + "Mila", + "Cameron", + "Leah", + "Wyatt", + "Elizabeth", + "Dylan", + "Sarah", + "Nathan", + "Eliana", + "Nicholas", + "Mackenzie", + "Julian", + "Peyton", + "Eli", + "Maria", + "Levi", + "Grace", + "Isaiah", + "Adeline", + "Landon", + "Elena", + "David", + "Anna", + "Christian", + "Victoria", + "Andrew", + "Camilla", + "Brayden", + "Lillian", + "John", + "Natalie", + "Lincoln" + ); + } +} diff --git a/examples/src/main/java/io/grpc/examples/manualflowcontrol/ManualFlowControlServer.java b/examples/src/main/java/io/grpc/examples/manualflowcontrol/ManualFlowControlServer.java index de8142596ea..3b7f980e08c 100644 --- a/examples/src/main/java/io/grpc/examples/manualflowcontrol/ManualFlowControlServer.java +++ b/examples/src/main/java/io/grpc/examples/manualflowcontrol/ManualFlowControlServer.java @@ -72,6 +72,7 @@ public void run() { // Give gRPC a StreamObserver that can observe and process incoming requests. return new StreamObserver() { + int cnt = 0; @Override public void onNext(HelloRequest request) { // Process the request and send a response or an error. 
@@ -81,7 +82,8 @@ public void onNext(HelloRequest request) { logger.info("--> " + name); // Simulate server "work" - Thread.sleep(100); + int sleepMillis = ++cnt % 20 == 0 ? 2000 : 100; + Thread.sleep(sleepMillis); // Send a response. String message = "Hello " + name; diff --git a/examples/src/main/java/io/grpc/examples/manualflowcontrol/README.md b/examples/src/main/java/io/grpc/examples/manualflowcontrol/README.md index a30688cea15..f700d428aca 100644 --- a/examples/src/main/java/io/grpc/examples/manualflowcontrol/README.md +++ b/examples/src/main/java/io/grpc/examples/manualflowcontrol/README.md @@ -1,5 +1,5 @@ -gRPC Manual Flow Control Example -===================== +# gRPC Manual Flow Control Example + Flow control is relevant for streaming RPC calls. By default, gRPC will handle dealing with flow control. However, for specific @@ -25,14 +25,13 @@ value. ### Outgoing Flow Control -The underlying layer (such as Netty) will make the write wait when there is no -space to write the next message. This causes the request stream to go into -a not ready state and the outgoing onNext method invocation waits. You can -explicitly check that the stream is ready for writing before calling onNext to -avoid blocking. This is done with `CallStreamObserver.isReady()`. You can -utilize this to start doing reads, which may allow -the other side of the channel to complete a write and then to do its own reads, -thereby avoiding deadlock. +The underlying layer (such as Netty) manages a buffer for outgoing messages. If +you write messages faster than they can be sent over the network, this buffer +will grow, which can eventually lead to an OutOfMemoryError. The outgoing onNext +method invocation does not block when this happens. Therefore, you should +explicitly check that the stream is ready for writing via +`CallStreamObserver.isReady()` before generating messages to avoid buffering +excessive amounts of data in memory. 
### Incoming Manual Flow Control @@ -71,6 +70,7 @@ When you are ready to begin processing the next value from the stream call `serverCallStreamObserver.request(1)` ### Related documents + Also see [gRPC Flow Control Users Guide][user guide] - [user guide]: https://grpc.io/docs/guides/flow-control \ No newline at end of file +[user guide]: https://grpc.io/docs/guides/flow-control diff --git a/examples/src/main/java/io/grpc/examples/multiplex/README.md b/examples/src/main/java/io/grpc/examples/multiplex/README.md new file mode 100644 index 00000000000..fb24642a41b --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/multiplex/README.md @@ -0,0 +1,20 @@ +gRPC Multiplex Example +===================== + +gRPC multiplexing refers to the ability of a single gRPC connection to handle multiple independent streams of communication simultaneously. +This is part of the HTTP/2 protocol on which gRPC is built. +Each gRPC connection supports multiple streams that can carry different RPCs, making it highly efficient for high-throughput, low-latency communication. + +In gRPC, sharing resources like channels and servers can improve efficiency and resource utilization. + +- Sharing gRPC Channels and Servers + + 1. Shared gRPC Channel: + - A single gRPC channel can be used by multiple stubs, enabling different service clients to communicate over the same connection. + - This minimizes the overhead of establishing and managing multiple connections + + 2. Shared gRPC Server: + - A single gRPC server can host multiple services, enabling different services to be exposed over the same port and connection. + - This minimizes the overhead of running and managing multiple server instances + +This example demonstrates how to implement a gRPC server that serves both a GreetingService and an EchoService, and a client that shares a single channel across multiple stubs for both services. 
\ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/nameresolve/README.md b/examples/src/main/java/io/grpc/examples/nameresolve/README.md new file mode 100644 index 00000000000..36c8d7e2a6b --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/nameresolve/README.md @@ -0,0 +1,22 @@ +gRPC Name Resolve Example +===================== + +This example explains standard name resolution process and how to implement it using the Name Resolver component. + +Name Resolution is fundamentally about Service Discovery. +Name Resolution refers to the process of converting a name into an address and +Name Resolver is the component that implements the Name Resolution process. + +When sending gRPC Request, Client must determine the IP address of the Service Name, +By Default DNS Name Resolution will be used when request received from the gRPC client. + +The Name Resolver in gRPC is necessary because clients often don’t know the exact IP address or port of the server +they need to connect to. + +The client registers an implementation of a **name resolver provider** to a process-global **registry** close to the start of the process. +The name resolver provider will be called by the **gRPC library** with a **target strings** intended for the custom name resolver. +Given that target string, the name resolver provider will return an instance of a **name resolver**, +which will interact with the client connection to direct the request according to the target string. 
+ +Refer the gRPC documentation for more on Name Resolution and Custom Name Resolution +https://grpc.io/docs/guides/custom-name-resolution/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/preserialized/README.md b/examples/src/main/java/io/grpc/examples/preserialized/README.md new file mode 100644 index 00000000000..d49b3507d03 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/preserialized/README.md @@ -0,0 +1,18 @@ +gRPC Pre-Serialized Messages Example +===================== + +This example gives the usage and implementation of pre-serialized request and response messages +communication/exchange between grpc client and server by using ByteArrayMarshaller which produces +a byte[] instead of decoding into typical POJOs. + +This is a performance optimization that can be useful if you read the request/response from on-disk or a database +where it is already serialized, or if you need to send the same complicated message to many clients and servers. +The same approach can avoid deserializing requests/responses, to be stored in a database. + +It shows how to modify MethodDescriptor to use bytes as the response instead of HelloReply. By +adjusting toBuilder() you can choose which of the request and response are bytes. +The generated bindService() uses ServerCalls to make RPC handlers, Since the generated +bindService() won't expect byte[] in the AsyncService, this uses ServerCalls directly. + +Stubs use ClientCalls to send RPCs, Since the generated stub won't have byte[] in its +method signature, this uses ClientCalls directly. 
\ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/retrying/README.md b/examples/src/main/java/io/grpc/examples/retrying/README.md new file mode 100644 index 00000000000..bb29ce75e43 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/retrying/README.md @@ -0,0 +1,27 @@ +gRPC Retrying Example +===================== + +The Retrying example provides a HelloWorld gRPC client & +server which demos the effect of client retry policy configured on the [ManagedChannel]( +https://github.com/grpc/grpc-java/blob/master/api/src/main/java/io/grpc/ManagedChannel.java) via [gRPC ServiceConfig]( +https://github.com/grpc/grpc/blob/master/doc/service_config.md). Retry policy implementation & +configuration details are outlined in the [proposal](https://github.com/grpc/proposal/blob/master/A6-client-retries.md). + +This retrying example is very similar to the [hedging example](https://github.com/grpc/grpc-java/tree/master/examples/src/main/java/io/grpc/examples/hedging) in its setup. +The [RetryingHelloWorldServer](src/main/java/io/grpc/examples/retrying/RetryingHelloWorldServer.java) responds with +a status UNAVAILABLE error response to a specified percentage of requests to simulate server resource exhaustion and +general flakiness. The [RetryingHelloWorldClient](src/main/java/io/grpc/examples/retrying/RetryingHelloWorldClient.java) makes +a number of sequential requests to the server, several of which will be retried depending on the configured policy in +[retrying_service_config.json](https://github.com/grpc/grpc-java/blob/master/examples/src/main/resources/io/grpc/examples/retrying/retrying_service_config.json). Although +the requests are blocking unary calls for simplicity, these could easily be changed to future unary calls in order to +test the result of request concurrency with retry policy enabled. 
+ +One can experiment with the [RetryingHelloWorldServer](src/main/java/io/grpc/examples/retrying/RetryingHelloWorldServer.java) +failure conditions to simulate server throttling, as well as alter policy values in the [retrying_service_config.json]( +https://github.com/grpc/grpc-java/blob/master/examples/src/main/resources/io/grpc/examples/retrying/retrying_service_config.json) to see their effects. To disable retrying +entirely, set environment variable `DISABLE_RETRYING_IN_RETRYING_EXAMPLE=true` before running the client. +Disabling the retry policy should produce many more failed gRPC calls as seen in the output log. + +See [the section](https://github.com/grpc/grpc-java/tree/master/examples#-to-build-the-examples) for how to build and run the example. The +executables for the server and the client are `retrying-hello-world-server` and +`retrying-hello-world-client`. diff --git a/examples/src/main/java/io/grpc/examples/routeguide/README.md b/examples/src/main/java/io/grpc/examples/routeguide/README.md new file mode 100644 index 00000000000..2528b26410c --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/routeguide/README.md @@ -0,0 +1,24 @@ +gRPC Route Guide Example +===================== + +This example illustrates how to implement and use a gRPC server and client for a RouteGuide service, +which demonstrates all 4 types of gRPC methods (unary, client streaming, server streaming, and bidirectional streaming). +Additionally, the service loads geographic features from a JSON file [route_guide_db.json](https://github.com/grpc/grpc-java/blob/master/examples/src/main/resources/io/grpc/examples/routeguide/route_guide_db.json) and retrieves features based on latitude and longitude. + +The route_guide.proto file defines a gRPC service with 4 types of RPC methods, showcasing different communication patterns between client and server. +1. Unary RPC + - rpc GetFeature(Point) returns (Feature) {} +2. 
Server-Side Streaming RPC + - rpc ListFeatures(Rectangle) returns (stream Feature) {} +3. Client-Side Streaming RPC + - rpc RecordRoute(stream Point) returns (RouteSummary) {} +4. Bidirectional Streaming RPC + - rpc RouteChat(stream RouteNote) returns (stream RouteNote) {} + +These RPC methods illustrate the versatility of gRPC in handling various communication patterns, +from simple request-response interactions to complex bidirectional streaming scenarios. + +For more details, refer to the full route_guide.proto file on GitHub: https://github.com/grpc/grpc-java/blob/master/examples/src/main/proto/route_guide.proto + +Refer the gRPC documentation for more details on creation, build and execution of route guide example with explanation +https://grpc.io/docs/languages/java/basics/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/waitforready/README.md b/examples/src/main/java/io/grpc/examples/waitforready/README.md new file mode 100644 index 00000000000..1e294b453b6 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/waitforready/README.md @@ -0,0 +1,29 @@ +gRPC Wait-For-Ready Example +===================== + +This example gives the usage and implementation of the Wait-For-Ready feature. + +This feature can be activated on a client stub, ensuring that Remote Procedure Calls (RPCs) are held until the server is ready to receive them. +By waiting for the server to become available before sending requests, this mechanism enhances reliability, +particularly in situations where server availability may be delayed or unpredictable. + +When an RPC is initiated and the channel fails to connect to the server, its behavior depends on the Wait-for-Ready option: + +- Without Wait-for-Ready (Default Behavior): + + - The RPC will immediately fail if the channel cannot establish a connection, providing prompt feedback about the connectivity issue. + +- With Wait-for-Ready: + + - The RPC will not fail immediately. 
Instead, it will be queued and will wait until the connection is successfully established. + This approach is beneficial for handling temporary network disruptions more gracefully, ensuring the RPC is eventually executed once the connection is ready. + + +Example gives the Simple client that requests a greeting from the HelloWorldServer and defines waitForReady on the stub. + +To test this flow need to follow below steps: +- run this client without a server running(client rpc should hang) +- start the server (client rpc should complete) +- run this client again (client rpc should complete nearly immediately) + +Refer the gRPC documentation for more on Wait-For-Ready https://grpc.io/docs/guides/wait-for-ready/ \ No newline at end of file diff --git a/examples/src/main/proto/hello_streaming.proto b/examples/src/main/proto/hello_streaming.proto index 325b9093b0c..b4f0f5287dd 100644 --- a/examples/src/main/proto/hello_streaming.proto +++ b/examples/src/main/proto/hello_streaming.proto @@ -29,6 +29,7 @@ service StreamingGreeter { // The request message containing the user's name. 
message HelloRequest { string name = 1; + bytes padding = 2; } // The response message containing the greetings diff --git a/gae-interop-testing/gae-jdk8/build.gradle b/gae-interop-testing/gae-jdk8/build.gradle index a09a8e793c0..07033f403de 100644 --- a/gae-interop-testing/gae-jdk8/build.gradle +++ b/gae-interop-testing/gae-jdk8/build.gradle @@ -14,10 +14,6 @@ buildscript { // Configuration for building - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } - } dependencies { classpath 'com.squareup.okhttp:okhttp:2.7.4' } @@ -33,13 +29,6 @@ plugins { description = 'gRPC: gae interop testing (jdk8)' -repositories { - // repositories for Jar's you access in your code - mavenLocal() - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } -} - dependencies { providedCompile group: 'javax.servlet', name: 'servlet-api', version:'2.5' runtimeOnly 'com.google.appengine:appengine-api-1.0-sdk:1.9.59' @@ -53,7 +42,11 @@ dependencies { implementation libraries.junit implementation libraries.protobuf.java runtimeOnly libraries.netty.tcnative, libraries.netty.tcnative.classes - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } tasks.named("compileJava").configure { @@ -65,6 +58,7 @@ def createDefaultVersion() { return new java.text.SimpleDateFormat("yyyyMMdd't'HHmmss").format(new Date()) } +def nonShadowedProject = project // [START model] appengine { // App Engine tasks configuration @@ -74,13 +68,13 @@ appengine { deploy { // deploy configuration - projectId = 'GCLOUD_CONFIG' + projectId = nonShadowedProject.findProperty('gaeProjectId') ?: 'GCLOUD_CONFIG' // default - stop the current version - stopPreviousVersion = System.getProperty('gaeStopPreviousVersion') ?: true + stopPreviousVersion = 
nonShadowedProject.findProperty('gaeStopPreviousVersion') ?: true // default - do not make this the promoted version - promote = System.getProperty('gaePromote') ?: false - // Use -DgaeDeployVersion if set, otherwise the version is null and the plugin will generate it - version = System.getProperty('gaeDeployVersion', createDefaultVersion()) + promote = nonShadowedProject.findProperty('gaePromote') ?: false + // Use -PgaeDeployVersion if set, otherwise the version is null and the plugin will generate it + version = nonShadowedProject.findProperty('gaeDeployVersion') ?: createDefaultVersion() } } // [END model] @@ -90,6 +84,10 @@ version = '1.0-SNAPSHOT' // Version in generated output /** Returns the service name. */ String getGaeProject() { + def configuredProjectId = appengine.deploy.projectId + if (!"GCLOUD_CONFIG".equals(configuredProjectId)) { + return configuredProjectId + } def stream = new ByteArrayOutputStream() exec { executable 'gcloud' @@ -117,11 +115,8 @@ String getAppUrl(String project, String service, String version) { } tasks.register("runInteropTestRemote") { - dependsOn appengineDeploy + mustRunAfter appengineDeploy doLast { - // give remote app some time to settle down - sleep(20000) - def appUrl = getAppUrl( getGaeProject(), getService(project.getProjectDir().toPath()), diff --git a/gae-interop-testing/gae-jdk8/src/main/webapp/WEB-INF/appengine-web.xml b/gae-interop-testing/gae-jdk8/src/main/webapp/WEB-INF/appengine-web.xml index 2fcbe5d8221..715906ada47 100644 --- a/gae-interop-testing/gae-jdk8/src/main/webapp/WEB-INF/appengine-web.xml +++ b/gae-interop-testing/gae-jdk8/src/main/webapp/WEB-INF/appengine-web.xml @@ -14,6 +14,6 @@ java-gae-interop-test - java11 + java17 diff --git a/gcp-csm-observability/build.gradle b/gcp-csm-observability/build.gradle index e29a56b1052..bda54ca8146 100644 --- a/gcp-csm-observability/build.gradle +++ b/gcp-csm-observability/build.gradle @@ -28,5 +28,9 @@ dependencies { libraries.opentelemetry.sdk.testing, 
libraries.assertj.core // opentelemetry.sdk.testing uses compileOnly for this dep - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } diff --git a/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/MetadataExchanger.java b/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/MetadataExchanger.java index bf76c2532bc..5f05d52c7e7 100644 --- a/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/MetadataExchanger.java +++ b/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/MetadataExchanger.java @@ -16,7 +16,6 @@ package io.grpc.gcp.csm.observability; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.io.BaseEncoding; import com.google.protobuf.Struct; @@ -29,12 +28,9 @@ import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.Status; -import io.grpc.internal.JsonParser; -import io.grpc.internal.JsonUtil; import io.grpc.opentelemetry.InternalOpenTelemetryPlugin; import io.grpc.protobuf.ProtoUtils; import io.grpc.xds.ClusterImplLoadBalancerProvider; -import io.grpc.xds.InternalGrpcBootstrapperImpl; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.common.AttributesBuilder; @@ -43,8 +39,6 @@ import java.net.URI; import java.util.Map; import java.util.function.Consumer; -import java.util.logging.Level; -import java.util.logging.Logger; /** * OpenTelemetryPlugin implementing metadata-based workload property exchange for both client and @@ -52,7 +46,6 @@ * and remote details to metrics. 
*/ final class MetadataExchanger implements InternalOpenTelemetryPlugin { - private static final Logger logger = Logger.getLogger(MetadataExchanger.class.getName()); private static final AttributeKey CLOUD_PLATFORM = AttributeKey.stringKey("cloud.platform"); @@ -89,11 +82,10 @@ final class MetadataExchanger implements InternalOpenTelemetryPlugin { public MetadataExchanger() { this( addOtelResourceAttributes(new GCPResourceProvider().getAttributes()), - System::getenv, - InternalGrpcBootstrapperImpl::getJsonContent); + System::getenv); } - MetadataExchanger(Attributes platformAttributes, Lookup env, Supplier xdsBootstrap) { + MetadataExchanger(Attributes platformAttributes, Lookup env) { String type = platformAttributes.get(CLOUD_PLATFORM); String canonicalService = env.get("CSM_CANONICAL_SERVICE_NAME"); Struct.Builder struct = Struct.newBuilder(); @@ -121,7 +113,7 @@ public MetadataExchanger() { localMetadata = BaseEncoding.base64().encode(struct.build().toByteArray()); localAttributes = Attributes.builder() - .put("csm.mesh_id", nullIsUnknown(getMeshId(xdsBootstrap))) + .put("csm.mesh_id", nullIsUnknown(env.get("CSM_MESH_ID"))) .put("csm.workload_canonical_service", nullIsUnknown(canonicalService)) .build(); } @@ -162,29 +154,6 @@ private static Attributes addOtelResourceAttributes(Attributes platformAttribute return builder.build(); } - @VisibleForTesting - static String getMeshId(Supplier xdsBootstrap) { - try { - @SuppressWarnings("unchecked") - Map rawBootstrap = (Map) JsonParser.parse(xdsBootstrap.get()); - Map node = JsonUtil.getObject(rawBootstrap, "node"); - String id = JsonUtil.getString(node, "id"); - Preconditions.checkNotNull(id, "id"); - String[] parts = id.split("/", 6); - if (!(parts.length == 6 - && parts[0].equals("projects") - && parts[2].equals("networks") - && parts[3].startsWith("mesh:") - && parts[4].equals("nodes"))) { - throw new Exception("node id didn't match mesh format: " + id); - } - return parts[3].substring("mesh:".length()); - } 
catch (Exception e) { - logger.log(Level.INFO, "Failed to determine mesh ID for CSM", e); - return null; - } - } - private void addLabels(AttributesBuilder to, Struct struct) { to.putAll(localAttributes); Map remote = struct.getFieldsMap(); diff --git a/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/CsmObservabilityTest.java b/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/CsmObservabilityTest.java index 878bf30ce34..aba2c43c44f 100644 --- a/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/CsmObservabilityTest.java +++ b/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/CsmObservabilityTest.java @@ -77,17 +77,14 @@ public void tearDown() { @Test public void unknownDataExchange() throws Exception { - String xdsBootstrap = ""; MetadataExchanger clientExchanger = new MetadataExchanger( Attributes.builder().build(), - ImmutableMap.of()::get, - () -> xdsBootstrap); + ImmutableMap.of()::get); CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) .sdk(openTelemetryTesting.getOpenTelemetry()); MetadataExchanger serverExchanger = new MetadataExchanger( Attributes.builder().build(), - ImmutableMap.of()::get, - () -> xdsBootstrap); + ImmutableMap.of()::get); CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) .sdk(openTelemetryTesting.getOpenTelemetry()); @@ -140,11 +137,9 @@ public void unknownDataExchange() throws Exception { @Test public void nonCsmServer() throws Exception { - String xdsBootstrap = ""; MetadataExchanger clientExchanger = new MetadataExchanger( Attributes.builder().build(), - ImmutableMap.of()::get, - () -> xdsBootstrap); + ImmutableMap.of()::get); CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) .sdk(openTelemetryTesting.getOpenTelemetry()); @@ -205,19 +200,16 @@ public void nonCsmServer() throws Exception { @Test public void nonCsmClient() throws Exception { - String 
xdsBootstrap = ""; MetadataExchanger clientExchanger = new MetadataExchanger( Attributes.builder() .put(stringKey("cloud.platform"), "gcp_kubernetes_engine") .build(), - ImmutableMap.of()::get, - () -> xdsBootstrap); + ImmutableMap.of()::get); CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) .sdk(openTelemetryTesting.getOpenTelemetry()); MetadataExchanger serverExchanger = new MetadataExchanger( Attributes.builder().build(), - ImmutableMap.of()::get, - () -> xdsBootstrap); + ImmutableMap.of()::get); CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) .sdk(openTelemetryTesting.getOpenTelemetry()); @@ -262,11 +254,6 @@ public void nonCsmClient() throws Exception { @Test public void k8sExchange() throws Exception { - // Purposefully use a different project ID in the bootstrap than the resource, as the mesh could - // be in a different project than the running account. - String clientBootstrap = "{\"node\": {" - + "\"id\": \"projects/12/networks/mesh:mymesh/nodes/a6420022-cbc5-4e10-808c-507e3fc95f2e\"" - + "}}"; MetadataExchanger clientExchanger = new MetadataExchanger( Attributes.builder() .put(stringKey("cloud.platform"), "gcp_kubernetes_engine") @@ -277,13 +264,10 @@ public void k8sExchange() throws Exception { .build(), ImmutableMap.of( "CSM_CANONICAL_SERVICE_NAME", "canon-service-is-a-client", - "CSM_WORKLOAD_NAME", "best-client")::get, - () -> clientBootstrap); + "CSM_WORKLOAD_NAME", "best-client", + "CSM_MESH_ID", "mymesh")::get); CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) .sdk(openTelemetryTesting.getOpenTelemetry()); - String serverBootstrap = "{\"node\": {" - + "\"id\": \"projects/34/networks/mesh:meshhh/nodes/4969ef19-24b6-44c0-baf3-86d188ff5967\"" - + "}}"; MetadataExchanger serverExchanger = new MetadataExchanger( Attributes.builder() .put(stringKey("cloud.platform"), "gcp_kubernetes_engine") @@ -295,8 +279,8 @@ public void 
k8sExchange() throws Exception { .build(), ImmutableMap.of( "CSM_CANONICAL_SERVICE_NAME", "server-has-a-single-name", - "CSM_WORKLOAD_NAME", "fast-server")::get, - () -> serverBootstrap); + "CSM_WORKLOAD_NAME", "fast-server", + "CSM_MESH_ID", "meshhh")::get); CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) .sdk(openTelemetryTesting.getOpenTelemetry()); @@ -366,11 +350,6 @@ public void k8sExchange() throws Exception { @Test public void gceExchange() throws Exception { - // Purposefully use a different project ID in the bootstrap than the resource, as the mesh could - // be in a different project than the running account. - String clientBootstrap = "{\"node\": {" - + "\"id\": \"projects/12/networks/mesh:mymesh/nodes/a6420022-cbc5-4e10-808c-507e3fc95f2e\"" - + "}}"; MetadataExchanger clientExchanger = new MetadataExchanger( Attributes.builder() .put(stringKey("cloud.platform"), "gcp_compute_engine") @@ -379,13 +358,10 @@ public void gceExchange() throws Exception { .build(), ImmutableMap.of( "CSM_CANONICAL_SERVICE_NAME", "canon-service-is-a-client", - "CSM_WORKLOAD_NAME", "best-client")::get, - () -> clientBootstrap); + "CSM_WORKLOAD_NAME", "best-client", + "CSM_MESH_ID", "mymesh")::get); CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) .sdk(openTelemetryTesting.getOpenTelemetry()); - String serverBootstrap = "{\"node\": {" - + "\"id\": \"projects/34/networks/mesh:meshhh/nodes/4969ef19-24b6-44c0-baf3-86d188ff5967\"" - + "}}"; MetadataExchanger serverExchanger = new MetadataExchanger( Attributes.builder() .put(stringKey("cloud.platform"), "gcp_compute_engine") @@ -395,8 +371,8 @@ public void gceExchange() throws Exception { .build(), ImmutableMap.of( "CSM_CANONICAL_SERVICE_NAME", "server-has-a-single-name", - "CSM_WORKLOAD_NAME", "fast-server")::get, - () -> serverBootstrap); + "CSM_WORKLOAD_NAME", "fast-server", + "CSM_MESH_ID", "meshhh")::get); CsmObservability.Builder serverCsmBuilder 
= new CsmObservability.Builder(serverExchanger) .sdk(openTelemetryTesting.getOpenTelemetry()); @@ -456,9 +432,6 @@ public void gceExchange() throws Exception { @Test public void trailersOnly() throws Exception { - String clientBootstrap = "{\"node\": {" - + "\"id\": \"projects/12/networks/mesh:mymesh/nodes/a6420022-cbc5-4e10-808c-507e3fc95f2e\"" - + "}}"; MetadataExchanger clientExchanger = new MetadataExchanger( Attributes.builder() .put(stringKey("cloud.platform"), "gcp_compute_engine") @@ -467,13 +440,11 @@ public void trailersOnly() throws Exception { .build(), ImmutableMap.of( "CSM_CANONICAL_SERVICE_NAME", "canon-service-is-a-client", - "CSM_WORKLOAD_NAME", "best-client")::get, - () -> clientBootstrap); + "CSM_WORKLOAD_NAME", "best-client", + "CSM_MESH_ID", "mymesh")::get); CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) .sdk(openTelemetryTesting.getOpenTelemetry()); - String serverBootstrap = "{\"node\": {" - + "\"id\": \"projects/34/networks/mesh:meshhh/nodes/4969ef19-24b6-44c0-baf3-86d188ff5967\"" - + "}}"; + MetadataExchanger serverExchanger = new MetadataExchanger( Attributes.builder() .put(stringKey("cloud.platform"), "gcp_compute_engine") @@ -483,8 +454,8 @@ public void trailersOnly() throws Exception { .build(), ImmutableMap.of( "CSM_CANONICAL_SERVICE_NAME", "server-has-a-single-name", - "CSM_WORKLOAD_NAME", "fast-server")::get, - () -> serverBootstrap); + "CSM_WORKLOAD_NAME", "fast-server", + "CSM_MESH_ID", "meshhh")::get); CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) .sdk(openTelemetryTesting.getOpenTelemetry()); diff --git a/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/MetadataExchangerTest.java b/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/MetadataExchangerTest.java index 20665e502f7..cc3472be182 100644 --- a/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/MetadataExchangerTest.java +++ 
b/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/MetadataExchangerTest.java @@ -33,56 +33,11 @@ /** Tests for {@link MetadataExchanger}. */ @RunWith(JUnit4.class) public final class MetadataExchangerTest { - @Test - public void getMeshId_findsMeshId() { - assertThat(MetadataExchanger.getMeshId(() -> - "{\"node\":{\"id\":\"projects/12/networks/mesh:mine/nodes/uu-id\"}}")) - .isEqualTo("mine"); - assertThat(MetadataExchanger.getMeshId(() -> - "{\"node\":{\"id\":\"projects/1234567890/networks/mesh:mine/nodes/uu-id\", " - + "\"unknown\": \"\"}, \"unknown\": \"\"}")) - .isEqualTo("mine"); - } - - @Test - public void getMeshId_returnsNullOnBadMeshId() { - assertThat(MetadataExchanger.getMeshId( - () -> "[\"node\"]")) - .isNull(); - assertThat(MetadataExchanger.getMeshId( - () -> "{\"node\":[\"id\"]}}")) - .isNull(); - assertThat(MetadataExchanger.getMeshId( - () -> "{\"node\":{\"id\":[\"projects/12/networks/mesh:mine/nodes/uu-id\"]}}")) - .isNull(); - - assertThat(MetadataExchanger.getMeshId( - () -> "{\"NODE\":{\"id\":\"projects/12/networks/mesh:mine/nodes/uu-id\"}}")) - .isNull(); - assertThat(MetadataExchanger.getMeshId( - () -> "{\"node\":{\"ID\":\"projects/12/networks/mesh:mine/nodes/uu-id\"}}")) - .isNull(); - assertThat(MetadataExchanger.getMeshId( - () -> "{\"node\":{\"id\":\"projects/12/networks/mesh:mine\"}}")) - .isNull(); - assertThat(MetadataExchanger.getMeshId( - () -> "{\"node\":{\"id\":\"PROJECTS/12/networks/mesh:mine/nodes/uu-id\"}}")) - .isNull(); - assertThat(MetadataExchanger.getMeshId( - () -> "{\"node\":{\"id\":\"projects/12/NETWORKS/mesh:mine/nodes/uu-id\"}}")) - .isNull(); - assertThat(MetadataExchanger.getMeshId( - () -> "{\"node\":{\"id\":\"projects/12/networks/MESH:mine/nodes/uu-id\"}}")) - .isNull(); - assertThat(MetadataExchanger.getMeshId( - () -> "{\"node\":{\"id\":\"projects/12/networks/mesh:mine/NODES/uu-id\"}}")) - .isNull(); - } @Test public void enablePluginForChannel_matches() { MetadataExchanger exchanger = - new 
MetadataExchanger(Attributes.builder().build(), (name) -> null, () -> ""); + new MetadataExchanger(Attributes.builder().build(), (name) -> null); assertThat(exchanger.enablePluginForChannel("xds:///testing")).isTrue(); assertThat(exchanger.enablePluginForChannel("xds:/testing")).isTrue(); assertThat(exchanger.enablePluginForChannel( @@ -92,7 +47,7 @@ public void enablePluginForChannel_matches() { @Test public void enablePluginForChannel_doesNotMatch() { MetadataExchanger exchanger = - new MetadataExchanger(Attributes.builder().build(), (name) -> null, () -> ""); + new MetadataExchanger(Attributes.builder().build(), (name) -> null); assertThat(exchanger.enablePluginForChannel("dns:///localhost")).isFalse(); assertThat(exchanger.enablePluginForChannel("xds:///[]")).isFalse(); assertThat(exchanger.enablePluginForChannel("xds://my-xds-server/testing")).isFalse(); @@ -101,7 +56,7 @@ public void enablePluginForChannel_doesNotMatch() { @Test public void addLabels_receivedWrongType() { MetadataExchanger exchanger = - new MetadataExchanger(Attributes.builder().build(), (name) -> null, () -> ""); + new MetadataExchanger(Attributes.builder().build(), (name) -> null); Metadata metadata = new Metadata(); metadata.put(Metadata.Key.of("x-envoy-peer-metadata", Metadata.ASCII_STRING_MARSHALLER), BaseEncoding.base64().encode(Struct.newBuilder() @@ -122,7 +77,7 @@ public void addLabels_receivedWrongType() { @Test public void addLabelsFromExchange_unknownGcpType() { MetadataExchanger exchanger = - new MetadataExchanger(Attributes.builder().build(), (name) -> null, () -> ""); + new MetadataExchanger(Attributes.builder().build(), (name) -> null); Metadata metadata = new Metadata(); metadata.put(Metadata.Key.of("x-envoy-peer-metadata", Metadata.ASCII_STRING_MARSHALLER), BaseEncoding.base64().encode(Struct.newBuilder() @@ -153,8 +108,7 @@ public void addMetadata_k8s() throws Exception { .build(), ImmutableMap.of( "CSM_CANONICAL_SERVICE_NAME", "myservice1", - "CSM_WORKLOAD_NAME", 
"myworkload1")::get, - () -> ""); + "CSM_WORKLOAD_NAME", "myworkload1")::get); Metadata metadata = new Metadata(); exchanger.newClientCallPlugin().addMetadata(metadata); @@ -182,8 +136,7 @@ public void addMetadata_gce() throws Exception { .build(), ImmutableMap.of( "CSM_CANONICAL_SERVICE_NAME", "myservice1", - "CSM_WORKLOAD_NAME", "myworkload1")::get, - () -> ""); + "CSM_WORKLOAD_NAME", "myworkload1")::get); Metadata metadata = new Metadata(); exchanger.newClientCallPlugin().addMetadata(metadata); diff --git a/gcp-observability/build.gradle b/gcp-observability/build.gradle index f869bd61a76..1d8c7a9f961 100644 --- a/gcp-observability/build.gradle +++ b/gcp-observability/build.gradle @@ -59,7 +59,6 @@ dependencies { project(path: ':grpc-alts', configuration: 'shadow'), project(':grpc-auth'), // Align grpc versions project(':grpc-core'), // Align grpc versions - project(':grpc-grpclb'), // Align grpc versions project(':grpc-services'), // Align grpc versions libraries.animalsniffer.annotations, // Use our newer version libraries.auto.value.annotations, // Use our newer version @@ -74,7 +73,11 @@ dependencies { exclude group: 'junit', module: 'junit' } - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } configureProtoCompilation() diff --git a/gcp-observability/interop/build.gradle b/gcp-observability/interop/build.gradle index 4a78c056eac..7e17624995a 100644 --- a/gcp-observability/interop/build.gradle +++ b/gcp-observability/interop/build.gradle @@ -10,7 +10,11 @@ dependencies { implementation project(':grpc-interop-testing'), project(':grpc-gcp-observability') - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } application { diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java index 
497a1eda30f..7fe4e3a8a3c 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java @@ -127,6 +127,15 @@ static GcpObservability grpcInit( /** Un-initialize/shutdown grpc-observability. */ @Override public void close() { + closeWithSleepTime(2 * METRICS_EXPORT_INTERVAL, TimeUnit.SECONDS); + } + + /** + * Method to close along with sleep time explicitly. + * + * @param sleepTime sleepTime + */ + void closeWithSleepTime(long sleepTime, TimeUnit timeUnit) { synchronized (GcpObservability.class) { if (instance == null) { throw new IllegalStateException("GcpObservability already closed!"); @@ -135,8 +144,7 @@ public void close() { if (config.isEnableCloudMonitoring() || config.isEnableCloudTracing()) { try { // Sleeping before shutdown to ensure all metrics and traces are flushed - Thread.sleep( - TimeUnit.MILLISECONDS.convert(2 * METRICS_EXPORT_INTERVAL, TimeUnit.SECONDS)); + timeUnit.sleep(sleepTime); } catch (InterruptedException e) { Thread.currentThread().interrupt(); logger.log(Level.SEVERE, "Caught exception during sleep", e); diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/GcpObservabilityTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/GcpObservabilityTest.java index 40f2fb01490..25467839dd6 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/GcpObservabilityTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/GcpObservabilityTest.java @@ -45,6 +45,7 @@ import io.opencensus.trace.samplers.Samplers; import java.io.IOException; import java.util.List; +import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; import org.junit.Test; import org.junit.runner.RunWith; @@ -196,9 +197,9 @@ public void run() { mock(InternalLoggingServerInterceptor.Factory.class); when(serverInterceptorFactory.create()).thenReturn(serverInterceptor); - try (GcpObservability 
unused = - GcpObservability.grpcInit( - sink, config, channelInterceptorFactory, serverInterceptorFactory)) { + try { + GcpObservability gcpObservability = GcpObservability.grpcInit( + sink, config, channelInterceptorFactory, serverInterceptorFactory); List configurators = InternalConfiguratorRegistry.getConfigurators(); assertThat(configurators).hasSize(1); ObservabilityConfigurator configurator = (ObservabilityConfigurator) configurators.get(0); @@ -208,9 +209,11 @@ public void run() { assertThat(list.get(2)).isInstanceOf(ConditionalClientInterceptor.class); assertThat(configurator.serverInterceptors).hasSize(1); assertThat(configurator.tracerFactories).hasSize(2); + gcpObservability.closeWithSleepTime(3000, TimeUnit.MILLISECONDS); } catch (Exception e) { fail("Encountered exception: " + e); } + verify(sink).close(); } } diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java index ee711cad097..92e67b01e01 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java @@ -73,7 +73,7 @@ public class LoggingTest { /** * Cloud logging test using global interceptors. * - *

Ignoring test, because it calls external Cloud Logging APIs. + *

Ignoring test, because it calls external Cloud Logging APIs. * To test cloud logging setup locally, * 1. Set up Cloud auth credentials * 2. Assign permissions to service account to write logs to project specified by diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java index a9e0d6e2235..f409a149bf1 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java @@ -108,8 +108,7 @@ public class ObservabilityConfigImplTest { private static final String PROJECT_ID = "{\n" + " \"project_id\": \"grpc-testing\",\n" - + " \"cloud_logging\": {},\n" - + " \"project_id\": \"grpc-testing\"\n" + + " \"cloud_logging\": {}\n" + "}"; private static final String EMPTY_CONFIG = "{}"; diff --git a/googleapis/BUILD.bazel b/googleapis/BUILD.bazel index 9ce4179f7b7..5b62b21cb3a 100644 --- a/googleapis/BUILD.bazel +++ b/googleapis/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( @@ -12,5 +13,6 @@ java_library( "//core:internal", "//xds", artifact("com.google.guava:guava"), + artifact("com.google.errorprone:error_prone_annotations"), ], ) diff --git a/googleapis/build.gradle b/googleapis/build.gradle index 435e552d47d..3a7a3a2766a 100644 --- a/googleapis/build.gradle +++ b/googleapis/build.gradle @@ -21,5 +21,9 @@ dependencies { libraries.guava.jre // JRE required by transitive protobuf-java-util testImplementation testFixtures(project(':grpc-core')) - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } diff --git a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdExperimentalNameResolverProvider.java 
b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdExperimentalNameResolverProvider.java index 349e1c94380..db674aeb2ee 100644 --- a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdExperimentalNameResolverProvider.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdExperimentalNameResolverProvider.java @@ -20,6 +20,7 @@ import io.grpc.NameResolver; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; +import io.grpc.Uri; import java.net.URI; /** @@ -35,6 +36,11 @@ public NameResolver newNameResolver(URI targetUri, Args args) { return delegate.newNameResolver(targetUri, args); } + @Override + public NameResolver newNameResolver(Uri targetUri, Args args) { + return delegate.newNameResolver(targetUri, args); + } + @Override public String getDefaultScheme() { return delegate.getDefaultScheme(); diff --git a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java index ebc7dd05ea4..427c0658531 100644 --- a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java @@ -20,18 +20,27 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; -import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.io.CharStreams; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.NameResolverRegistry; import io.grpc.Status; import io.grpc.SynchronizationContext; +import io.grpc.Uri; import io.grpc.alts.InternalCheckGcpEnvironment; import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder.Resource; +import 
io.grpc.xds.InternalGrpcBootstrapperImpl; +import io.grpc.xds.InternalSharedXdsClientPoolProvider; +import io.grpc.xds.InternalSharedXdsClientPoolProvider.XdsClientResult; +import io.grpc.xds.XdsNameResolverProvider; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsInitializationException; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; @@ -41,7 +50,7 @@ import java.net.URISyntaxException; import java.net.URL; import java.nio.charset.StandardCharsets; -import java.util.Map; +import java.util.List; import java.util.Random; import java.util.concurrent.Executor; import java.util.logging.Level; @@ -63,52 +72,55 @@ final class GoogleCloudToProdNameResolver extends NameResolver { static final String C2P_AUTHORITY = "traffic-director-c2p.xds.googleapis.com"; @VisibleForTesting static boolean isOnGcp = InternalCheckGcpEnvironment.isOnGcp(); - @VisibleForTesting - static boolean xdsBootstrapProvided = - System.getenv("GRPC_XDS_BOOTSTRAP") != null - || System.getProperty("io.grpc.xds.bootstrap") != null - || System.getenv("GRPC_XDS_BOOTSTRAP_CONFIG") != null - || System.getProperty("io.grpc.xds.bootstrapConfig") != null; - @VisibleForTesting - static boolean enableFederation = - Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_XDS_FEDERATION")) - || Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_XDS_FEDERATION")); private static final String serverUriOverride = System.getenv("GRPC_TEST_ONLY_GOOGLE_C2P_RESOLVER_TRAFFIC_DIRECTOR_URI"); - private HttpConnectionProvider httpConnectionProvider = HttpConnectionFactory.INSTANCE; + @GuardedBy("GoogleCloudToProdNameResolver.class") + private static BootstrapInfo bootstrapInfo; + private static HttpConnectionProvider httpConnectionProvider = HttpConnectionFactory.INSTANCE; + private static int c2pId = new Random().nextInt(); + + private static synchronized BootstrapInfo getBootstrapInfo() + throws 
XdsInitializationException, IOException { + if (bootstrapInfo != null) { + return bootstrapInfo; + } + BootstrapInfo bootstrapInfoTmp = + InternalGrpcBootstrapperImpl.parseBootstrap(generateBootstrap()); + // Avoid setting global when testing + if (httpConnectionProvider == HttpConnectionFactory.INSTANCE) { + bootstrapInfo = bootstrapInfoTmp; + } + return bootstrapInfoTmp; + } + private final String authority; private final SynchronizationContext syncContext; private final Resource executorResource; - private final BootstrapSetter bootstrapSetter; + private final String target; + private final MetricRecorder metricRecorder; private final NameResolver delegate; - private final Random rand; private final boolean usingExecutorResource; - // It's not possible to use both PSM and DirectPath C2P in the same application. - // Delegate to DNS if user-provided bootstrap is found. - private final String schemeOverride = - !isOnGcp - || (xdsBootstrapProvided && !enableFederation) - ? "dns" : "xds"; + private final String schemeOverride = !isOnGcp ? "dns" : "xds"; + private XdsClientResult xdsClientPool; + private XdsClient xdsClient; private Executor executor; private Listener2 listener; private boolean succeeded; private boolean resolving; private boolean shutdown; - GoogleCloudToProdNameResolver(URI targetUri, Args args, Resource executorResource, - BootstrapSetter bootstrapSetter) { - this(targetUri, args, executorResource, new Random(), bootstrapSetter, + GoogleCloudToProdNameResolver(URI targetUri, Args args, Resource executorResource) { + this(targetUri, args, executorResource, NameResolverRegistry.getDefaultRegistry().asFactory()); } + // TODO(jdcormie): Remove after io.grpc.Uri migration. 
@VisibleForTesting GoogleCloudToProdNameResolver(URI targetUri, Args args, Resource executorResource, - Random rand, BootstrapSetter bootstrapSetter, NameResolver.Factory nameResolverFactory) { + NameResolver.Factory nameResolverFactory) { this.executorResource = checkNotNull(executorResource, "executorResource"); - this.bootstrapSetter = checkNotNull(bootstrapSetter, "bootstrapSetter"); - this.rand = checkNotNull(rand, "rand"); String targetPath = checkNotNull(checkNotNull(targetUri, "targetUri").getPath(), "targetPath"); Preconditions.checkArgument( targetPath.startsWith("/"), @@ -118,15 +130,59 @@ final class GoogleCloudToProdNameResolver extends NameResolver { authority = GrpcUtil.checkAuthority(targetPath.substring(1)); syncContext = checkNotNull(args, "args").getSynchronizationContext(); targetUri = overrideUriScheme(targetUri, schemeOverride); - if (schemeOverride.equals("xds") && enableFederation) { + if (schemeOverride.equals("xds")) { targetUri = overrideUriAuthority(targetUri, C2P_AUTHORITY); + args = args.toBuilder() + .setArg(XdsNameResolverProvider.XDS_CLIENT_SUPPLIER, () -> xdsClient) + .build(); } + target = targetUri.toString(); + metricRecorder = args.getMetricRecorder(); delegate = checkNotNull(nameResolverFactory, "nameResolverFactory").newNameResolver( targetUri, args); executor = args.getOffloadExecutor(); usingExecutorResource = executor == null; } + GoogleCloudToProdNameResolver(Uri targetUri, Args args, Resource executorResource) { + this(targetUri, args, executorResource, NameResolverRegistry.getDefaultRegistry().asFactory()); + } + + @VisibleForTesting + GoogleCloudToProdNameResolver( + Uri targetUri, + Args args, + Resource executorResource, + NameResolver.Factory nameResolverFactory) { + this.executorResource = checkNotNull(executorResource, "executorResource"); + Preconditions.checkArgument( + targetUri.isPathAbsolute(), + "the path component of the target (%s) must start with '/'", + targetUri); + List pathSegments = 
targetUri.getPathSegments(); + Preconditions.checkArgument( + pathSegments.size() == 1, + "the path component of the target (%s) must have exactly one segment", + targetUri); + authority = GrpcUtil.checkAuthority(pathSegments.get(0)); + syncContext = checkNotNull(args, "args").getSynchronizationContext(); + Uri.Builder modifiedTargetBuilder = targetUri.toBuilder().setScheme(schemeOverride); + if (schemeOverride.equals("xds")) { + modifiedTargetBuilder.setRawAuthority(C2P_AUTHORITY); + args = + args.toBuilder() + .setArg(XdsNameResolverProvider.XDS_CLIENT_SUPPLIER, () -> xdsClient) + .build(); + } + targetUri = modifiedTargetBuilder.build(); + target = targetUri.toString(); + metricRecorder = args.getMetricRecorder(); + delegate = + checkNotNull(nameResolverFactory, "nameResolverFactory").newNameResolver(targetUri, args); + executor = args.getOffloadExecutor(); + usingExecutorResource = executor == null; + } + @Override public String getServiceAuthority() { return authority; @@ -150,7 +206,7 @@ private void resolve() { resolving = true; if (logger.isLoggable(Level.FINE)) { - logger.fine("resolve with schemaOverride = " + schemeOverride); + logger.log(Level.FINE, "start with schemaOverride = {0}", schemeOverride); } if (schemeOverride.equals("dns")) { @@ -168,28 +224,28 @@ private void resolve() { class Resolve implements Runnable { @Override public void run() { - ImmutableMap rawBootstrap = null; + BootstrapInfo bootstrapInfo = null; try { - // User provided bootstrap configs are only supported with federation. If federation is - // not enabled or there is no user provided config, we set a custom bootstrap override. - // Otherwise, we don't set the override, which will allow a user provided bootstrap config - // to take effect. 
- if (!enableFederation || !xdsBootstrapProvided) { - rawBootstrap = generateBootstrap(queryZoneMetadata(METADATA_URL_ZONE), - queryIpv6SupportMetadata(METADATA_URL_SUPPORT_IPV6)); - } + bootstrapInfo = getBootstrapInfo(); } catch (IOException e) { listener.onError( Status.INTERNAL.withDescription("Unable to get metadata").withCause(e)); + } catch (XdsInitializationException e) { + listener.onError( + Status.INTERNAL.withDescription("Unable to create c2p bootstrap").withCause(e)); + } catch (Throwable t) { + listener.onError( + Status.INTERNAL.withDescription("Unexpected error creating c2p bootstrap") + .withCause(t)); } finally { - final ImmutableMap finalRawBootstrap = rawBootstrap; + final BootstrapInfo finalBootstrapInfo = bootstrapInfo; syncContext.execute(new Runnable() { @Override public void run() { - if (!shutdown) { - if (finalRawBootstrap != null) { - bootstrapSetter.setBootstrap(finalRawBootstrap); - } + if (!shutdown && finalBootstrapInfo != null) { + xdsClientPool = InternalSharedXdsClientPoolProvider.getOrCreate( + target, finalBootstrapInfo, metricRecorder, null); + xdsClient = xdsClientPool.getObject(); delegate.start(listener); succeeded = true; } @@ -203,9 +259,16 @@ public void run() { executor.execute(new Resolve()); } - private ImmutableMap generateBootstrap(String zone, boolean supportIpv6) { + @VisibleForTesting + static ImmutableMap generateBootstrap() throws IOException { + return generateBootstrap( + queryZoneMetadata(METADATA_URL_ZONE), + queryIpv6SupportMetadata(METADATA_URL_SUPPORT_IPV6)); + } + + private static ImmutableMap generateBootstrap(String zone, boolean supportIpv6) { ImmutableMap.Builder nodeBuilder = ImmutableMap.builder(); - nodeBuilder.put("id", "C2P-" + (rand.nextInt() & Integer.MAX_VALUE)); + nodeBuilder.put("id", "C2P-" + (c2pId & Integer.MAX_VALUE)); if (!zone.isEmpty()) { nodeBuilder.put("locality", ImmutableMap.of("zone", zone)); } @@ -250,12 +313,15 @@ public void shutdown() { if (delegate != null) { 
delegate.shutdown(); } + if (xdsClient != null) { + xdsClient = xdsClientPool.returnObject(xdsClient); + } if (executor != null && usingExecutorResource) { executor = SharedResourceHolder.release(executorResource, executor); } } - private String queryZoneMetadata(String url) throws IOException { + private static String queryZoneMetadata(String url) throws IOException { HttpURLConnection con = null; String respBody; try { @@ -275,7 +341,7 @@ private String queryZoneMetadata(String url) throws IOException { return index == -1 ? "" : respBody.substring(index + 1); } - private boolean queryIpv6SupportMetadata(String url) throws IOException { + private static boolean queryIpv6SupportMetadata(String url) throws IOException { HttpURLConnection con = null; try { con = httpConnectionProvider.createConnection(url); @@ -294,8 +360,17 @@ private boolean queryIpv6SupportMetadata(String url) throws IOException { } @VisibleForTesting - void setHttpConnectionProvider(HttpConnectionProvider httpConnectionProvider) { - this.httpConnectionProvider = httpConnectionProvider; + static void setHttpConnectionProvider(HttpConnectionProvider httpConnectionProvider) { + if (httpConnectionProvider == null) { + GoogleCloudToProdNameResolver.httpConnectionProvider = HttpConnectionFactory.INSTANCE; + } else { + GoogleCloudToProdNameResolver.httpConnectionProvider = httpConnectionProvider; + } + } + + @VisibleForTesting + static void setC2pId(int c2pId) { + GoogleCloudToProdNameResolver.c2pId = c2pId; } private static URI overrideUriScheme(URI uri, String scheme) { @@ -335,8 +410,4 @@ public HttpURLConnection createConnection(String url) throws IOException { interface HttpConnectionProvider { HttpURLConnection createConnection(String url) throws IOException; } - - public interface BootstrapSetter { - void setBootstrap(Map bootstrap); - } } diff --git a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java 
b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java index 8ad292a3d98..f936de086e9 100644 --- a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java @@ -21,14 +21,13 @@ import io.grpc.NameResolver; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; +import io.grpc.Uri; import io.grpc.internal.GrpcUtil; -import io.grpc.xds.InternalSharedXdsClientPoolProvider; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; import java.util.Collection; import java.util.Collections; -import java.util.Map; /** * A provider for {@link GoogleCloudToProdNameResolver}. @@ -48,12 +47,21 @@ public GoogleCloudToProdNameResolverProvider() { this.scheme = Preconditions.checkNotNull(scheme, "scheme"); } + // TODO(jdcormie): Remove after io.grpc.Uri migration is complete. @Override public NameResolver newNameResolver(URI targetUri, Args args) { if (scheme.equals(targetUri.getScheme())) { return new GoogleCloudToProdNameResolver( - targetUri, args, GrpcUtil.SHARED_CHANNEL_EXECUTOR, - new SharedXdsClientPoolProviderBootstrapSetter()); + targetUri, args, GrpcUtil.SHARED_CHANNEL_EXECUTOR); + } + return null; + } + + @Override + public NameResolver newNameResolver(Uri targetUri, Args args) { + if (scheme.equals(targetUri.getScheme())) { + return new GoogleCloudToProdNameResolver( + targetUri, args, GrpcUtil.SHARED_CHANNEL_EXECUTOR); } return null; } @@ -77,12 +85,4 @@ protected int priority() { public Collection> getProducedSocketAddressTypes() { return Collections.singleton(InetSocketAddress.class); } - - private static final class SharedXdsClientPoolProviderBootstrapSetter - implements GoogleCloudToProdNameResolver.BootstrapSetter { - @Override - public void setBootstrap(Map bootstrap) { - InternalSharedXdsClientPoolProvider.setDefaultProviderBootstrapOverride(bootstrap); - } - } 
} diff --git a/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProviderTest.java b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProviderTest.java index 447b102c8c7..39468472985 100644 --- a/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProviderTest.java +++ b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProviderTest.java @@ -23,20 +23,23 @@ import io.grpc.ChannelLogger; import io.grpc.InternalServiceProviders; import io.grpc.NameResolver; +import io.grpc.NameResolver.Args; import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.NameResolverProvider; import io.grpc.SynchronizationContext; +import io.grpc.Uri; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; import java.net.URI; +import java.util.Arrays; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; -/** - * Unit tests for {@link GoogleCloudToProdNameResolverProvider}. - */ -@RunWith(JUnit4.class) +/** Unit tests for {@link GoogleCloudToProdNameResolverProvider}. 
*/ +@RunWith(Parameterized.class) public class GoogleCloudToProdNameResolverProviderTest { private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @@ -59,6 +62,13 @@ public void uncaughtException(Thread t, Throwable e) { private GoogleCloudToProdNameResolverProvider provider = new GoogleCloudToProdNameResolverProvider(); + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + @Test public void provided() { for (NameResolverProvider current @@ -84,16 +94,24 @@ NameResolverProvider.class, getClass().getClassLoader())) { } @Test - public void newNameResolver() { - assertThat(provider - .newNameResolver(URI.create("google-c2p:///foo.googleapis.com"), args)) + public void shouldProvideNameResolverOfExpectedType() { + assertThat(newNameResolver(provider, "google-c2p:///foo.googleapis.com", args)) .isInstanceOf(GoogleCloudToProdNameResolver.class); } @Test - public void experimentalNewNameResolver() { - assertThat(new GoogleCloudToProdExperimentalNameResolverProvider() - .newNameResolver(URI.create("google-c2p-experimental:///foo.googleapis.com"), args)) + public void shouldProvideExperimentalNameResolverOfExpectedType() { + assertThat( + newNameResolver( + new GoogleCloudToProdExperimentalNameResolverProvider(), + "google-c2p-experimental:///foo.googleapis.com", + args)) .isInstanceOf(GoogleCloudToProdNameResolver.class); } + + private NameResolver newNameResolver(NameResolverProvider provider, String uri, Args args) { + return enableRfc3986UrisParam + ? 
provider.newNameResolver(Uri.create(uri), args) + : provider.newNameResolver(URI.create(uri), args); + } } diff --git a/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java index edb3126d1e3..d3d3cfc4bff 100644 --- a/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java +++ b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import io.grpc.ChannelLogger; +import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.NameResolver.Args; import io.grpc.NameResolver.ServiceConfigParser; @@ -33,6 +34,7 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; +import io.grpc.Uri; import io.grpc.googleapis.GoogleCloudToProdNameResolver.HttpConnectionProvider; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; @@ -42,31 +44,33 @@ import java.net.HttpURLConnection; import java.net.URI; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class GoogleCloudToProdNameResolverTest { @Rule public final 
MockitoRule mocks = MockitoJUnit.rule(); - private static final URI TARGET_URI = URI.create("google-c2p:///googleapis.com"); + private static final String TARGET_URI = "google-c2p:///googleapis.com"; private static final String ZONE = "us-central1-a"; private static final int DEFAULT_PORT = 887; @@ -77,15 +81,16 @@ public void uncaughtException(Thread t, Throwable e) { throw new AssertionError(e); } }); + private final FakeClock fakeExecutor = new FakeClock(); private final NameResolver.Args args = NameResolver.Args.newBuilder() .setDefaultPort(DEFAULT_PORT) .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR) .setSynchronizationContext(syncContext) + .setScheduledExecutorService(fakeExecutor.getScheduledExecutorService()) .setServiceConfigParser(mock(ServiceConfigParser.class)) .setChannelLogger(mock(ChannelLogger.class)) + .setMetricRecorder(new MetricRecorder() {}) .build(); - private final FakeClock fakeExecutor = new FakeClock(); - private final FakeBootstrapSetter fakeBootstrapSetter = new FakeBootstrapSetter(); private final Resource fakeExecutorResource = new Resource() { @Override public Executor create() { @@ -101,34 +106,25 @@ public void close(Executor instance) {} @Mock private NameResolver.Listener2 mockListener; - private Random random = new Random(1); @Captor private ArgumentCaptor errorCaptor; private boolean originalIsOnGcp; - private boolean originalXdsBootstrapProvided; private GoogleCloudToProdNameResolver resolver; + private String responseToIpV6 = "1:1:1"; + + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; @Before public void setUp() { nsRegistry.register(new FakeNsProvider("dns")); nsRegistry.register(new FakeNsProvider("xds")); originalIsOnGcp = GoogleCloudToProdNameResolver.isOnGcp; - originalXdsBootstrapProvided = GoogleCloudToProdNameResolver.xdsBootstrapProvided; - } - - @After - public 
void tearDown() { - GoogleCloudToProdNameResolver.isOnGcp = originalIsOnGcp; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = originalXdsBootstrapProvided; - resolver.shutdown(); - verify(Iterables.getOnlyElement(delegatedResolver.values())).shutdown(); - } - private void createResolver() { - createResolver("1:1:1"); - } - - private void createResolver(String responseToIpV6) { HttpConnectionProvider httpConnections = new HttpConnectionProvider() { @Override public HttpURLConnection createConnection(String url) throws IOException { @@ -148,10 +144,28 @@ public HttpURLConnection createConnection(String url) throws IOException { throw new AssertionError("Unknown http query"); } }; - resolver = new GoogleCloudToProdNameResolver( - TARGET_URI, args, fakeExecutorResource, random, fakeBootstrapSetter, - nsRegistry.asFactory()); - resolver.setHttpConnectionProvider(httpConnections); + GoogleCloudToProdNameResolver.setHttpConnectionProvider(httpConnections); + + GoogleCloudToProdNameResolver.setC2pId(new Random(1).nextInt()); + } + + @After + public void tearDown() { + GoogleCloudToProdNameResolver.isOnGcp = originalIsOnGcp; + GoogleCloudToProdNameResolver.setHttpConnectionProvider(null); + if (resolver != null) { + resolver.shutdown(); + verify(Iterables.getOnlyElement(delegatedResolver.values())).shutdown(); + } + } + + private void createResolver() { + resolver = + enableRfc3986UrisParam + ? 
new GoogleCloudToProdNameResolver( + Uri.create(TARGET_URI), args, fakeExecutorResource, nsRegistry.asFactory()) + : new GoogleCloudToProdNameResolver( + URI.create(TARGET_URI), args, fakeExecutorResource, nsRegistry.asFactory()); } @Test @@ -164,27 +178,19 @@ public void notOnGcp_DelegateToDns() { } @Test - public void hasProvidedBootstrap_DelegateToDns() { + public void onGcpAndNoProvidedBootstrap_DelegateToXds() { GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = true; - GoogleCloudToProdNameResolver.enableFederation = false; createResolver(); resolver.start(mockListener); - assertThat(delegatedResolver.keySet()).containsExactly("dns"); + fakeExecutor.runDueTasks(); + assertThat(delegatedResolver.keySet()).containsExactly("xds"); verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); } @SuppressWarnings("unchecked") @Test - public void onGcpAndNoProvidedBootstrap_DelegateToXds() { - GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = false; - createResolver(); - resolver.start(mockListener); - fakeExecutor.runDueTasks(); - assertThat(delegatedResolver.keySet()).containsExactly("xds"); - verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); - Map bootstrap = fakeBootstrapSetter.bootstrapRef.get(); + public void generateBootstrap_ipv6() throws IOException { + Map bootstrap = GoogleCloudToProdNameResolver.generateBootstrap(); Map node = (Map) bootstrap.get("node"); assertThat(node).containsExactly( "id", "C2P-991614323", @@ -204,15 +210,9 @@ public void onGcpAndNoProvidedBootstrap_DelegateToXds() { @SuppressWarnings("unchecked") @Test - public void onGcpAndNoProvidedBootstrap_DelegateToXds_noIpV6() { - GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = false; - createResolver(null); - resolver.start(mockListener); - fakeExecutor.runDueTasks(); - 
assertThat(delegatedResolver.keySet()).containsExactly("xds"); - verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); - Map bootstrap = fakeBootstrapSetter.bootstrapRef.get(); + public void generateBootstrap_noIpV6() throws IOException { + responseToIpV6 = null; + Map bootstrap = GoogleCloudToProdNameResolver.generateBootstrap(); Map node = (Map) bootstrap.get("node"); assertThat(node).containsExactly( "id", "C2P-991614323", @@ -231,70 +231,18 @@ public void onGcpAndNoProvidedBootstrap_DelegateToXds_noIpV6() { @SuppressWarnings("unchecked") @Test - public void emptyResolverMeetadataValue() { - GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = false; - createResolver(""); - resolver.start(mockListener); - fakeExecutor.runDueTasks(); - assertThat(delegatedResolver.keySet()).containsExactly("xds"); - verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); - Map bootstrap = fakeBootstrapSetter.bootstrapRef.get(); + public void emptyResolverMeetadataValue() throws IOException { + responseToIpV6 = ""; + Map bootstrap = GoogleCloudToProdNameResolver.generateBootstrap(); Map node = (Map) bootstrap.get("node"); assertThat(node).containsExactly( "id", "C2P-991614323", "locality", ImmutableMap.of("zone", ZONE)); } - @SuppressWarnings("unchecked") - @Test - public void onGcpAndNoProvidedBootstrapAndFederationEnabled_DelegateToXds() { - GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = false; - GoogleCloudToProdNameResolver.enableFederation = true; - createResolver(); - resolver.start(mockListener); - fakeExecutor.runDueTasks(); - assertThat(delegatedResolver.keySet()).containsExactly("xds"); - verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); - // check bootstrap - Map bootstrap = fakeBootstrapSetter.bootstrapRef.get(); - Map node = (Map) bootstrap.get("node"); - 
assertThat(node).containsExactly( - "id", "C2P-991614323", - "locality", ImmutableMap.of("zone", ZONE), - "metadata", ImmutableMap.of("TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE", true)); - Map server = Iterables.getOnlyElement( - (List>) bootstrap.get("xds_servers")); - assertThat(server).containsExactly( - "server_uri", "directpath-pa.googleapis.com", - "channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default")), - "server_features", ImmutableList.of("xds_v3", "ignore_resource_deletion")); - Map authorities = (Map) bootstrap.get("authorities"); - assertThat(authorities).containsExactly( - "traffic-director-c2p.xds.googleapis.com", - ImmutableMap.of("xds_servers", ImmutableList.of(server))); - } - - @SuppressWarnings("unchecked") - @Test - public void onGcpAndProvidedBootstrapAndFederationEnabled_DontDelegateToXds() { - GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = true; - GoogleCloudToProdNameResolver.enableFederation = true; - createResolver(); - resolver.start(mockListener); - fakeExecutor.runDueTasks(); - assertThat(delegatedResolver.keySet()).containsExactly("xds"); - verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); - // Bootstrapper should not have been set, since there was no user provided config. 
- assertThat(fakeBootstrapSetter.bootstrapRef.get()).isNull(); - } - @Test public void failToQueryMetadata() { GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = false; createResolver(); HttpConnectionProvider httpConnections = new HttpConnectionProvider() { @Override @@ -304,7 +252,7 @@ public HttpURLConnection createConnection(String url) throws IOException { return con; } }; - resolver.setHttpConnectionProvider(httpConnections); + GoogleCloudToProdNameResolver.setHttpConnectionProvider(httpConnections); resolver.start(mockListener); fakeExecutor.runDueTasks(); verify(mockListener).onError(errorCaptor.capture()); @@ -344,14 +292,4 @@ public String getDefaultScheme() { return scheme; } } - - private static final class FakeBootstrapSetter - implements GoogleCloudToProdNameResolver.BootstrapSetter { - private final AtomicReference> bootstrapRef = new AtomicReference<>(); - - @Override - public void setBootstrap(Map bootstrap) { - bootstrapRef.set(bootstrap); - } - } } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 488ead9ad86..705026a3fe3 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,103 +1,153 @@ [versions] -netty = '4.1.110.Final' -# Keep the following references of tcnative version in sync whenever it's updated: -# SECURITY.md -nettytcnative = '2.0.65.Final' opencensus = "0.31.1" -# Not upgrading to 4.x as it is not yet ABI compatible. -# https://github.com/protocolbuffers/protobuf/issues/17247 -protobuf = "3.25.3" [libraries] android-annotations = "com.google.android:annotations:4.1.1.4" -androidx-annotation = "androidx.annotation:annotation:1.8.0" +# 1.9.1+ uses Kotlin and requires Android Gradle Plugin 9+ +# checkForUpdates: androidx-annotation:1.9.0 +androidx-annotation = "androidx.annotation:annotation:1.9.0" +# 1.14.x doesn't exist. +# 1.15.0+ requires compileSdkVersion 35 which officially requires AGP 8.6.0+. 
+# It might work before then, but AGP 7.4.1 fails with: +# RES_TABLE_TYPE_TYPE entry offsets overlap actual entry data. +# 1.16.0+ requires AGP 8.6.0+ +# checkForUpdates: androidx-core:1.13.+ androidx-core = "androidx.core:core:1.13.1" -androidx-lifecycle-common = "androidx.lifecycle:lifecycle-common:2.8.3" -androidx-lifecycle-service = "androidx.lifecycle:lifecycle-service:2.8.3" -androidx-test-core = "androidx.test:core:1.6.1" -androidx-test-ext-junit = "androidx.test.ext:junit:1.2.1" -androidx-test-rules = "androidx.test:rules:1.6.1" -animalsniffer = "org.codehaus.mojo:animal-sniffer:1.24" -animalsniffer-annotations = "org.codehaus.mojo:animal-sniffer-annotations:1.24" -assertj-core = "org.assertj:assertj-core:3.26.0" +# 2.9+ requires AGP 8.1.1+ +# checkForUpdates: androidx-lifecycle-common:2.8.+ +androidx-lifecycle-common = "androidx.lifecycle:lifecycle-common:2.8.7" +# checkForUpdates: androidx-lifecycle-service:2.8.+ +androidx-lifecycle-service = "androidx.lifecycle:lifecycle-service:2.8.7" +androidx-test-core = "androidx.test:core:1.7.0" +androidx-test-ext-junit = "androidx.test.ext:junit:1.3.0" +androidx-test-rules = "androidx.test:rules:1.7.0" +androidx-test-runner = "androidx.test:runner:1.7.0" +animalsniffer = "org.codehaus.mojo:animal-sniffer:1.27" +animalsniffer-annotations = "org.codehaus.mojo:animal-sniffer-annotations:1.27" +assertj-core = "org.assertj:assertj-core:3.27.7" +# 1.11.1 started converting jsr305 @Nullable to jspecify +# checkForUpdates: auto-value:1.11.0 auto-value = "com.google.auto.value:auto-value:1.11.0" +# checkForUpdates: auto-value-annotations:1.11.0 auto-value-annotations = "com.google.auto.value:auto-value-annotations:1.11.0" -checkstyle = "com.puppycrawl.tools:checkstyle:10.17.0" +# 11.0+ requires Java 17+ +# https://checkstyle.sourceforge.io/releasenotes.html +# checkForUpdates: checkstyle:10.+ +checkstyle = "com.puppycrawl.tools:checkstyle:10.26.1" +# checkstyle 10.0+ requires Java 11+ +# See 
https://checkstyle.sourceforge.io/releasenotes_old_8-35_10-26.html#Release_10.0 +# checkForUpdates: checkstylejava8:9.+ +checkstylejava8 = "com.puppycrawl.tools:checkstyle:9.3" commons-math3 = "org.apache.commons:commons-math3:3.6.1" conscrypt = "org.conscrypt:conscrypt-openjdk-uber:2.5.2" +# 141.7340.3+ requires Java 17+ +# checkForUpdates: cronet-api:119.6045.31 cronet-api = "org.chromium.net:cronet-api:119.6045.31" +# checkForUpdates: cronet-embedded:119.6045.31 cronet-embedded = "org.chromium.net:cronet-embedded:119.6045.31" -errorprone-annotations = "com.google.errorprone:error_prone_annotations:2.28.0" -errorprone-core = "com.google.errorprone:error_prone_core:2.28.0" -google-api-protos = "com.google.api.grpc:proto-google-common-protos:2.41.0" -google-auth-credentials = "com.google.auth:google-auth-library-credentials:1.23.0" -google-auth-oauth2Http = "com.google.auth:google-auth-library-oauth2-http:1.23.0" +errorprone-annotations = "com.google.errorprone:error_prone_annotations:2.48.0" +# 2.32.0+ requires Java 17+ +# checkForUpdates: errorprone-core:2.31.+ +errorprone-core = "com.google.errorprone:error_prone_core:2.31.0" +# 2.11.0+ requires JDK 11+ (See https://github.com/google/error-prone/releases/tag/v2.11.0) +# checkForUpdates: errorprone-corejava8:2.10.+ +errorprone-corejava8 = "com.google.errorprone:error_prone_core:2.10.0" +# 2.65.0+ requires protobuf 4.x +# checkForUpdates: google-api-protos:2.64.+ +google-api-protos = "com.google.api.grpc:proto-google-common-protos:2.64.1" +# 1.43.0+ versions of google-auth-library requires protobuf 4.x +# checkForUpdates: google-auth-credentials:1.42.+ +google-auth-credentials = "com.google.auth:google-auth-library-credentials:1.42.1" +# checkForUpdates: google-auth-oauth2Http:1.42.+ +google-auth-oauth2Http = "com.google.auth:google-auth-library-oauth2-http:1.42.1" # Release notes: https://cloud.google.com/logging/docs/release-notes -google-cloud-logging = "com.google.cloud:google-cloud-logging:3.19.0" -gson = 
"com.google.code.gson:gson:2.11.0" -guava = "com.google.guava:guava:33.2.1-android" +# 3.23.11+ require protobuf 4.x +# checkForUpdates: google-cloud-logging:3.23.10 +google-cloud-logging = "com.google.cloud:google-cloud-logging:3.23.10" +gson = "com.google.code.gson:gson:2.13.2" +guava = "com.google.guava:guava:33.5.0-android" guava-betaChecker = "com.google.guava:guava-beta-checker:1.0" -guava-testlib = "com.google.guava:guava-testlib:33.2.1-android" +guava-testlib = "com.google.guava:guava-testlib:33.5.0-android" # JRE version is needed for projects where its a transitive dependency, f.e. gcp-observability. # May be different from the -android version. -guava-jre = "com.google.guava:guava:33.2.1-jre" +guava-jre = "com.google.guava:guava:33.5.0-jre" hdrhistogram = "org.hdrhistogram:HdrHistogram:2.2.2" +# 6.0.0+ use java.lang.Deprecated forRemoval and since from Java 9 +# checkForUpdates: jakarta-servlet-api:5.+ jakarta-servlet-api = "jakarta.servlet:jakarta.servlet-api:5.0.0" -javax-annotation = "org.apache.tomcat:annotations-api:6.0.53" javax-servlet-api = "javax.servlet:javax.servlet-api:4.0.1" -jetty-client = "org.eclipse.jetty:jetty-client:10.0.20" -jetty-http2-server = "org.eclipse.jetty.http2:http2-server:11.0.22" -jetty-http2-server10 = "org.eclipse.jetty.http2:http2-server:10.0.20" -jetty-servlet = "org.eclipse.jetty:jetty-servlet:11.0.22" -jetty-servlet10 = "org.eclipse.jetty:jetty-servlet:10.0.20" +# 12.0.0+ require Java 17+ +# checkForUpdates: jetty-client:11.+ +jetty-client = "org.eclipse.jetty:jetty-client:11.0.26" +jetty-http2-server = "org.eclipse.jetty.http2:jetty-http2-server:12.1.7" +# 10.0.25+ uses @Deprecated(since=/forRemoval=) from Java 9 +# checkForUpdates: jetty-http2-server10:10.0.24 +jetty-http2-server10 = "org.eclipse.jetty.http2:http2-server:10.0.24" +jetty-servlet = "org.eclipse.jetty.ee10:jetty-ee10-servlet:12.1.7" +# checkForUpdates: jetty-servlet10:10.0.24 +jetty-servlet10 = "org.eclipse.jetty:jetty-servlet:10.0.24" jsr305 = 
"com.google.code.findbugs:jsr305:3.0.2" junit = "junit:junit:4.13.2" -# 2.17+ require Java 11+ (not mentioned in release notes) -lincheck = "org.jetbrains.kotlinx:lincheck:2.16" +lincheck = "org.jetbrains.lincheck:lincheck:3.4" # Update notes / 2023-07-19 sergiitk: # Couldn't update to 5.4.0, updated to the last in 4.x line. Version 5.x breaks some tests. # Error log: https://github.com/grpc/grpc-java/pull/10359#issuecomment-1632834435 # Update notes / 2023-10-09 temawi: -# 4.11.0 Has been breaking the android integration tests as mockito now uses streams +# 4.5.0 Has been breaking the android integration tests as mockito now uses streams # (not available in API levels < 24). https://github.com/grpc/grpc-java/issues/10457 +# checkForUpdates: mockito-android:4.4.+ mockito-android = "org.mockito:mockito-android:4.4.0" +# checkForUpdates: mockito-core:4.4.+ mockito-core = "org.mockito:mockito-core:4.4.0" -netty-codec-http2 = { module = "io.netty:netty-codec-http2", version.ref = "netty" } -netty-handler-proxy = { module = "io.netty:netty-handler-proxy", version.ref = "netty" } -netty-tcnative = { module = "io.netty:netty-tcnative-boringssl-static", version.ref = "nettytcnative" } -netty-tcnative-classes = { module = "io.netty:netty-tcnative-classes", version.ref = "nettytcnative" } -netty-transport-epoll = { module = "io.netty:netty-transport-native-epoll", version.ref = "netty" } -netty-unix-common = { module = "io.netty:netty-transport-native-unix-common", version.ref = "netty" } +# Need to decide when we require users to absorb the breaking changes in 4.2 +# checkForUpdates: netty-codec-http2:4.1.+ +netty-codec-http2 = "io.netty:netty-codec-http2:4.1.132.Final" +# checkForUpdates: netty-handler-proxy:4.1.+ +netty-handler-proxy = "io.netty:netty-handler-proxy:4.1.132.Final" +# Keep the following references of tcnative version in sync whenever it's updated: +# SECURITY.md +netty-tcnative = "io.netty:netty-tcnative-boringssl-static:2.0.75.Final" 
+netty-tcnative-classes = "io.netty:netty-tcnative-classes:2.0.75.Final" +# checkForUpdates: netty-transport-epoll:4.1.+ +netty-transport-epoll = "io.netty:netty-transport-native-epoll:4.1.132.Final" +# checkForUpdates: netty-unix-common:4.1.+ +netty-unix-common = "io.netty:netty-transport-native-unix-common:4.1.132.Final" okhttp = "com.squareup.okhttp:okhttp:2.7.5" # okio 3.5+ uses Kotlin 1.9+ which requires Android Gradle Plugin 9+ +# checkForUpdates: okio:3.4.+ okio = "com.squareup.okio:okio:3.4.0" opencensus-api = { module = "io.opencensus:opencensus-api", version.ref = "opencensus" } opencensus-contrib-grpc-metrics = { module = "io.opencensus:opencensus-contrib-grpc-metrics", version.ref = "opencensus" } opencensus-exporter-stats-stackdriver = { module = "io.opencensus:opencensus-exporter-stats-stackdriver", version.ref = "opencensus" } opencensus-exporter-trace-stackdriver = { module = "io.opencensus:opencensus-exporter-trace-stackdriver", version.ref = "opencensus" } opencensus-impl = { module = "io.opencensus:opencensus-impl", version.ref = "opencensus" } -opentelemetry-api = "io.opentelemetry:opentelemetry-api:1.40.0" -opentelemetry-exporter-prometheus = "io.opentelemetry:opentelemetry-exporter-prometheus:1.40.0-alpha" -opentelemetry-gcp-resources = "io.opentelemetry.contrib:opentelemetry-gcp-resources:1.36.0-alpha" -opentelemetry-sdk-extension-autoconfigure = "io.opentelemetry:opentelemetry-sdk-extension-autoconfigure:1.40.0" -opentelemetry-sdk-testing = "io.opentelemetry:opentelemetry-sdk-testing:1.40.0" +opentelemetry-api = "io.opentelemetry:opentelemetry-api:1.60.1" +opentelemetry-exporter-prometheus = "io.opentelemetry:opentelemetry-exporter-prometheus:1.60.1-alpha" +opentelemetry-gcp-resources = "io.opentelemetry.contrib:opentelemetry-gcp-resources:1.54.0-alpha" +opentelemetry-sdk-extension-autoconfigure = "io.opentelemetry:opentelemetry-sdk-extension-autoconfigure:1.60.1" +opentelemetry-sdk-testing = 
"io.opentelemetry:opentelemetry-sdk-testing:1.60.1" perfmark-api = "io.perfmark:perfmark-api:0.27.0" -protobuf-java = { module = "com.google.protobuf:protobuf-java", version.ref = "protobuf" } -protobuf-java-util = { module = "com.google.protobuf:protobuf-java-util", version.ref = "protobuf" } -protobuf-javalite = { module = "com.google.protobuf:protobuf-javalite", version.ref = "protobuf" } -protobuf-protoc = { module = "com.google.protobuf:protoc", version.ref = "protobuf" } -re2j = "com.google.re2j:re2j:1.7" -robolectric = "org.robolectric:robolectric:4.13" -signature-android = "net.sf.androidscents.signature:android-api-level-19:4.4.2_r4" +# Not upgrading to 4.x as it is not yet ABI compatible. +# https://github.com/protocolbuffers/protobuf/issues/17247 +# checkForUpdates: protobuf-java:3.+ +protobuf-java = "com.google.protobuf:protobuf-java:3.25.8" +# checkForUpdates: protobuf-java-util:3.+ +protobuf-java-util = "com.google.protobuf:protobuf-java-util:3.25.8" +# checkForUpdates: protobuf-javalite:3.+ +protobuf-javalite = "com.google.protobuf:protobuf-javalite:3.25.8" +# checkForUpdates: protobuf-protoc:3.+ +protobuf-protoc = "com.google.protobuf:protoc:3.25.8" +re2j = "com.google.re2j:re2j:1.8" +robolectric = "org.robolectric:robolectric:4.16.1" +s2a-proto = "com.google.s2a.proto.v2:s2a-proto:0.1.3" +signature-android = "net.sf.androidscents.signature:android-api-level-21:5.0.1_r2" signature-java = "org.codehaus.mojo.signature:java18:1.0" -tomcat-embed-core = "org.apache.tomcat.embed:tomcat-embed-core:10.1.25" -tomcat-embed-core9 = "org.apache.tomcat.embed:tomcat-embed-core:9.0.89" -truth = "com.google.truth:truth:1.4.4" -undertow-servlet22 = "io.undertow:undertow-servlet:2.2.32.Final" -undertow-servlet = "io.undertow:undertow-servlet:2.3.14.Final" - -# Do not update: Pinned to the last version supporting Java 8. 
-# See https://checkstyle.sourceforge.io/releasenotes.html#Release_10.1 -checkstylejava8 = "com.puppycrawl.tools:checkstyle:9.3" -# See https://github.com/google/error-prone/releases/tag/v2.11.0 -errorprone-corejava8 = "com.google.errorprone:error_prone_core:2.10.0" +# 11.0.0+ require Java 17+ +# checkForUpdates: tomcat-embed-core:10.+ +tomcat-embed-core = "org.apache.tomcat.embed:tomcat-embed-core:10.1.52" +# checkForUpdates: tomcat-embed-core9:9.+ +tomcat-embed-core9 = "org.apache.tomcat.embed:tomcat-embed-core:9.0.115" +truth = "com.google.truth:truth:1.4.5" +# checkForUpdates: undertow-servlet22:2.2.+ +undertow-servlet22 = "io.undertow:undertow-servlet:2.2.38.Final" +undertow-servlet = "io.undertow:undertow-servlet:2.3.20.Final" diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index a4413138c96..d4081da476b 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.8-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/grpclb/BUILD.bazel b/grpclb/BUILD.bazel index 2dd24bb52a2..ca9975b7ce6 100644 --- a/grpclb/BUILD.bazel +++ b/grpclb/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") load("//:java_grpc_library.bzl", "java_grpc_library") @@ -20,6 +21,7 @@ java_library( "@com_google_protobuf//:protobuf_java_util", "@io_grpc_grpc_proto//:grpclb_load_balancer_java_proto", artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), ], ) diff --git a/grpclb/build.gradle b/grpclb/build.gradle index 93331053b09..e8896604f03 100644 --- 
a/grpclb/build.gradle +++ b/grpclb/build.gradle @@ -19,16 +19,20 @@ dependencies { implementation project(':grpc-core'), project(':grpc-protobuf'), project(':grpc-stub'), + project(':grpc-util'), libraries.guava, libraries.protobuf.java, libraries.protobuf.java.util runtimeOnly libraries.errorprone.annotations - compileOnly libraries.javax.annotation testImplementation libraries.truth, project(':grpc-inprocess'), testFixtures(project(':grpc-core')) - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } configureProtoCompilation() diff --git a/grpclb/src/generated/main/grpc/io/grpc/lb/v1/LoadBalancerGrpc.java b/grpclb/src/generated/main/grpc/io/grpc/lb/v1/LoadBalancerGrpc.java index c96c5400aac..b730eff7b37 100644 --- a/grpclb/src/generated/main/grpc/io/grpc/lb/v1/LoadBalancerGrpc.java +++ b/grpclb/src/generated/main/grpc/io/grpc/lb/v1/LoadBalancerGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/lb/v1/load_balancer.proto") @io.grpc.stub.annotations.GrpcGenerated public final class LoadBalancerGrpc { @@ -60,6 +57,21 @@ public LoadBalancerStub newStub(io.grpc.Channel channel, io.grpc.CallOptions cal return LoadBalancerStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static LoadBalancerBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public LoadBalancerBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerBlockingV2Stub(channel, callOptions); + } + }; + return LoadBalancerBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -147,6 +159,35 @@ 
public io.grpc.stub.StreamObserver balanceLoad /** * A stub to allow clients to do synchronous rpc calls to service LoadBalancer. */ + public static final class LoadBalancerBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private LoadBalancerBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected LoadBalancerBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerBlockingV2Stub(channel, callOptions); + } + + /** + *

+     * Bidirectional rpc to get a list of servers.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + balanceLoad() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getBalanceLoadMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service LoadBalancer. + */ public static final class LoadBalancerBlockingStub extends io.grpc.stub.AbstractBlockingStub { private LoadBalancerBlockingStub( diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java index d27c485dc13..fe928263ef9 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.protobuf.util.Timestamps; import io.grpc.ClientStreamTracer; import io.grpc.Metadata; @@ -29,7 +30,6 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicLongFieldUpdater; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java index 872937b03c1..bf9eea69af0 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java @@ -154,6 +154,8 @@ public void handleNameResolutionError(Status error) { } @Override + @Deprecated + @SuppressWarnings("InlineMeSuggester") public boolean canHandleEmptyAddressListFromNameResolution() { return true; } diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbNameResolver.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbNameResolver.java index d17587fb14d..60d02220e64 100644 
--- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbNameResolver.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbNameResolver.java @@ -21,6 +21,7 @@ import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.NameResolver; +import io.grpc.StatusOr; import io.grpc.internal.DnsNameResolver; import io.grpc.internal.SharedResourceHolder.Resource; import java.net.InetAddress; @@ -58,14 +59,22 @@ final class GrpclbNameResolver extends DnsNameResolver { } @Override - protected InternalResolutionResult doResolve(boolean forceTxt) { + protected ResolutionResult doResolve() { + ResolutionResult result = super.doResolve(); List balancerAddrs = resolveBalancerAddresses(); - InternalResolutionResult result = super.doResolve(!balancerAddrs.isEmpty()); if (!balancerAddrs.isEmpty()) { - result.attributes = - Attributes.newBuilder() + ResolutionResult.Builder resultBuilder = result.toBuilder() + .setAttributes(result.getAttributes().toBuilder() .set(GrpclbConstants.ATTR_LB_ADDRS, balancerAddrs) - .build(); + .build()); + if (!result.getAddressesOrError().hasValue()) { + // While ResolutionResult is powerful enough to communicate attributes simultaneously with + // an address resolution failure, LoadBalancer.ResolvedAddresses isn't yet and so the + // attributes are lost if addresses fail. GrpclbLB will be able to handle the lack of + // addresses since there are LB addresses, so discard the failure for now. 
+ resultBuilder.setAddressesOrError(StatusOr.fromValue(Collections.emptyList())); + } + result = resultBuilder.build(); } return result; } diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java index 49b74645ec8..5ed84ade2f8 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java @@ -37,13 +37,16 @@ import io.grpc.ConnectivityStateInfo; import io.grpc.Context; import io.grpc.EquivalentAddressGroup; -import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.FixedResultPicker; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; -import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.Status; @@ -62,6 +65,7 @@ import io.grpc.lb.v1.Server; import io.grpc.lb.v1.ServerList; import io.grpc.stub.StreamObserver; +import io.grpc.util.ForwardingLoadBalancerHelper; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; @@ -119,7 +123,7 @@ final class GrpclbState { @VisibleForTesting static final RoundRobinEntry BUFFER_ENTRY = new RoundRobinEntry() { @Override - public PickResult picked(Metadata headers) { + public PickResult picked(PickSubchannelArgs args) { return PickResult.withNoResult(); } @@ -183,10 +187,20 @@ enum Mode { private List dropList = Collections.emptyList(); // Contains only non-drop, i.e., backends from the round-robin list from the balancer. 
private List backendList = Collections.emptyList(); + private ConnectivityState currentState = ConnectivityState.CONNECTING; private RoundRobinPicker currentPicker = new RoundRobinPicker(Collections.emptyList(), Arrays.asList(BUFFER_ENTRY)); private boolean requestConnectionPending; + // Child LoadBalancer and state for PICK_FIRST mode delegation. + private final LoadBalancerProvider pickFirstLbProvider; + @Nullable + private LoadBalancer pickFirstLb; + private ConnectivityState pickFirstLbState = CONNECTING; + private SubchannelPicker pickFirstLbPicker = new FixedResultPicker(PickResult.withNoResult()); + @Nullable + private GrpclbClientLoadRecorder currentPickFirstLoadRecorder; + GrpclbState( GrpclbConfig config, Helper helper, @@ -212,6 +226,9 @@ public void onSubchannelState( } else { this.subchannelPool = null; } + this.pickFirstLbProvider = checkNotNull( + LoadBalancerRegistry.getDefaultRegistry().getProvider("pick_first"), + "pick_first balancer not available"); this.time = checkNotNull(time, "time provider"); this.stopwatch = checkNotNull(stopwatch, "stopwatch"); this.timerService = checkNotNull(helper.getScheduledExecutorService(), "timerService"); @@ -309,6 +326,12 @@ void handleAddresses( void requestConnection() { requestConnectionPending = true; + // For PICK_FIRST mode with delegation, forward to the child LB. + if (config.getMode() == Mode.PICK_FIRST && pickFirstLb != null) { + pickFirstLb.requestConnection(); + requestConnectionPending = false; + return; + } for (RoundRobinEntry entry : currentPicker.pickList) { if (entry instanceof IdleSubchannelEntry) { ((IdleSubchannelEntry) entry).subchannel.requestConnection(); @@ -323,15 +346,23 @@ private void maybeUseFallbackBackends() { } // Balancer RPC should have either been broken or timed out. 
checkState(fallbackReason != null, "no reason to fallback"); - for (Subchannel subchannel : subchannels.values()) { - ConnectivityStateInfo stateInfo = subchannel.getAttributes().get(STATE_INFO).get(); - if (stateInfo.getState() == READY) { + // For PICK_FIRST mode with delegation, check the child LB's state. + if (config.getMode() == Mode.PICK_FIRST) { + if (pickFirstLb != null && pickFirstLbState == READY) { return; } - // If we do have balancer-provided backends, use one of its error in the error message if - // fail to fallback. - if (stateInfo.getState() == TRANSIENT_FAILURE) { - fallbackReason = stateInfo.getStatus(); + // For PICK_FIRST, we don't have individual subchannel states to use as fallback reason. + } else { + for (Subchannel subchannel : subchannels.values()) { + ConnectivityStateInfo stateInfo = subchannel.getAttributes().get(STATE_INFO).get(); + if (stateInfo.getState() == READY) { + return; + } + // If we do have balancer-provided backends, use one of its error in the error message if + // fail to fallback. + if (stateInfo.getState() == TRANSIENT_FAILURE) { + fallbackReason = stateInfo.getStatus(); + } } } // Fallback conditions met @@ -355,11 +386,12 @@ private void useFallbackBackends() { } private void shutdownLbComm() { + shutdownLbRpc(); if (lbCommChannel != null) { - lbCommChannel.shutdown(); + // The channel should have no RPCs at this point + lbCommChannel.shutdownNow(); lbCommChannel = null; } - shutdownLbRpc(); } private void shutdownLbRpc() { @@ -438,9 +470,10 @@ void shutdown() { subchannelPool.clear(); break; case PICK_FIRST: - if (!subchannels.isEmpty()) { - checkState(subchannels.size() == 1, "Excessive Subchannels: %s", subchannels); - subchannels.values().iterator().next().shutdown(); + // Shutdown the child pick_first LB which manages its own subchannels. 
+ if (pickFirstLb != null) { + pickFirstLb.shutdown(); + pickFirstLb = null; } break; default: @@ -517,22 +550,17 @@ private void updateServerList( subchannels = Collections.unmodifiableMap(newSubchannelMap); break; case PICK_FIRST: - checkState(subchannels.size() <= 1, "Unexpected Subchannel count: %s", subchannels); - final Subchannel subchannel; + // Delegate to child pick_first LB for address management. + // Shutdown existing child LB if addresses become empty. if (newBackendAddrList.isEmpty()) { - if (subchannels.size() == 1) { - subchannel = subchannels.values().iterator().next(); - subchannel.shutdown(); - subchannels = Collections.emptyMap(); + if (pickFirstLb != null) { + pickFirstLb.shutdown(); + pickFirstLb = null; } break; } List eagList = new ArrayList<>(); - // Because for PICK_FIRST, we create a single Subchannel for all addresses, we have to - // attach the tokens to the EAG attributes and use TokenAttachingLoadRecorder to put them on - // headers. - // - // The PICK_FIRST code path doesn't cache Subchannels. + // Attach tokens to EAG attributes for TokenAttachingTracerFactory to retrieve. 
for (BackendAddressGroup bag : newBackendAddrList) { EquivalentAddressGroup origEag = bag.getAddresses(); Attributes eagAttrs = origEag.getAttributes(); @@ -542,30 +570,22 @@ private void updateServerList( } eagList.add(new EquivalentAddressGroup(origEag.getAddresses(), eagAttrs)); } - if (subchannels.isEmpty()) { - subchannel = - helper.createSubchannel( - CreateSubchannelArgs.newBuilder() - .setAddresses(eagList) - .setAttributes(createSubchannelAttrs()) - .build()); - subchannel.start(new SubchannelStateListener() { - @Override - public void onSubchannelState(ConnectivityStateInfo newState) { - handleSubchannelState(subchannel, newState); - } - }); - if (requestConnectionPending) { - subchannel.requestConnection(); - requestConnectionPending = false; - } - } else { - subchannel = subchannels.values().iterator().next(); - subchannel.updateAddresses(eagList); + + if (pickFirstLb == null) { + pickFirstLb = pickFirstLbProvider.newLoadBalancer(new PickFirstLbHelper()); + } + + // Pass addresses to child LB. + pickFirstLb.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(eagList) + .build()); + if (requestConnectionPending) { + pickFirstLb.requestConnection(); + requestConnectionPending = false; } - subchannels = Collections.singletonMap(eagList, subchannel); - newBackendList.add( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(loadRecorder))); + // Store the load recorder for token attachment. + currentPickFirstLoadRecorder = loadRecorder; break; default: throw new AssertionError("Missing case for " + config.getMode()); @@ -842,7 +862,11 @@ private void cleanUp() { private void maybeUpdatePicker() { List pickList; ConnectivityState state; - if (backendList.isEmpty()) { + // For PICK_FIRST mode with delegation, check if child LB exists instead of backendList. + boolean hasBackends = config.getMode() == Mode.PICK_FIRST + ? 
pickFirstLb != null + : !backendList.isEmpty(); + if (!hasBackends) { // Note balancer (is working) may enforce using fallback backends, and that fallback may // fail. So we should check if currently in fallback first. if (usingFallbackBackends) { @@ -894,26 +918,12 @@ private void maybeUpdatePicker() { } break; case PICK_FIRST: { - checkState(backendList.size() == 1, "Excessive backend entries: %s", backendList); - BackendEntry onlyEntry = backendList.get(0); - ConnectivityStateInfo stateInfo = - onlyEntry.subchannel.getAttributes().get(STATE_INFO).get(); - state = stateInfo.getState(); - switch (state) { - case READY: - pickList = Collections.singletonList(onlyEntry); - break; - case TRANSIENT_FAILURE: - pickList = - Collections.singletonList(new ErrorEntry(stateInfo.getStatus())); - break; - case CONNECTING: - pickList = Collections.singletonList(BUFFER_ENTRY); - break; - default: - pickList = Collections.singletonList( - new IdleSubchannelEntry(onlyEntry.subchannel, syncContext)); - } + // Use child LB's state and picker. Wrap the picker for token attachment. + state = pickFirstLbState; + TokenAttachingTracerFactory tracerFactory = + new TokenAttachingTracerFactory(currentPickFirstLoadRecorder); + pickList = Collections.singletonList( + new ChildLbPickerEntry(pickFirstLbPicker, tracerFactory)); break; } default: @@ -929,10 +939,12 @@ private void maybeUpdatePicker(ConnectivityState state, RoundRobinPicker picker) // Discard the new picker if we are sure it won't make any difference, in order to save // re-processing pending streams, and avoid unnecessary resetting of the pointer in // RoundRobinPicker. 
- if (picker.dropList.equals(currentPicker.dropList) + if (state.equals(currentState) + && picker.dropList.equals(currentPicker.dropList) && picker.pickList.equals(currentPicker.pickList)) { return; } + currentState = state; currentPicker = picker; helper.updateBalancingState(state, picker); } @@ -983,7 +995,7 @@ public boolean equals(Object other) { @VisibleForTesting interface RoundRobinEntry { - PickResult picked(Metadata headers); + PickResult picked(PickSubchannelArgs args); } @VisibleForTesting @@ -1024,7 +1036,8 @@ static final class BackendEntry implements RoundRobinEntry { } @Override - public PickResult picked(Metadata headers) { + public PickResult picked(PickSubchannelArgs args) { + Metadata headers = args.getHeaders(); headers.discardAll(GrpclbConstants.TOKEN_METADATA_KEY); if (token != null) { headers.put(GrpclbConstants.TOKEN_METADATA_KEY, token); @@ -1065,7 +1078,7 @@ static final class IdleSubchannelEntry implements RoundRobinEntry { } @Override - public PickResult picked(Metadata headers) { + public PickResult picked(PickSubchannelArgs args) { if (connectionRequested.compareAndSet(false, true)) { syncContext.execute(new Runnable() { @Override @@ -1108,7 +1121,7 @@ static final class ErrorEntry implements RoundRobinEntry { } @Override - public PickResult picked(Metadata headers) { + public PickResult picked(PickSubchannelArgs args) { return result; } @@ -1132,6 +1145,58 @@ public String toString() { } } + /** + * Entry that wraps a child LB's picker for PICK_FIRST mode delegation. + * Attaches TokenAttachingTracerFactory to the pick result for token propagation. 
+ */ + @VisibleForTesting + static final class ChildLbPickerEntry implements RoundRobinEntry { + private final SubchannelPicker childPicker; + private final TokenAttachingTracerFactory tracerFactory; + + ChildLbPickerEntry(SubchannelPicker childPicker, TokenAttachingTracerFactory tracerFactory) { + this.childPicker = checkNotNull(childPicker, "childPicker"); + this.tracerFactory = checkNotNull(tracerFactory, "tracerFactory"); + } + + @Override + public PickResult picked(PickSubchannelArgs args) { + PickResult childResult = childPicker.pickSubchannel(args); + if (childResult.getSubchannel() == null) { + // No subchannel (e.g., buffer, error), return as-is. + return childResult; + } + // Wrap the pick result to attach tokens via the tracer factory. + return PickResult.withSubchannel( + childResult.getSubchannel(), tracerFactory, childResult.getAuthorityOverride()); + } + + @Override + public int hashCode() { + return Objects.hashCode(childPicker, tracerFactory); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof ChildLbPickerEntry)) { + return false; + } + ChildLbPickerEntry that = (ChildLbPickerEntry) other; + return Objects.equal(childPicker, that.childPicker) + && Objects.equal(tracerFactory, that.tracerFactory); + } + + @Override + public String toString() { + return "ChildLbPickerEntry(" + childPicker + ")"; + } + + @VisibleForTesting + SubchannelPicker getChildPicker() { + return childPicker; + } + } + @VisibleForTesting static final class RoundRobinPicker extends SubchannelPicker { @VisibleForTesting @@ -1174,7 +1239,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { if (pickIndex == pickList.size()) { pickIndex = 0; } - return pick.picked(args.getHeaders()); + return pick.picked(args); } } @@ -1189,4 +1254,28 @@ public String toString() { return MoreObjects.toStringHelper(RoundRobinPicker.class).toString(); } } + + /** + * Helper for the child pick_first LB in PICK_FIRST mode. 
Intercepts updateBalancingState() + * to store state and trigger the grpclb picker update with drops and token attachment. + */ + private final class PickFirstLbHelper extends ForwardingLoadBalancerHelper { + + @Override + protected Helper delegate() { + return helper; + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + pickFirstLbState = newState; + pickFirstLbPicker = newPicker; + // Trigger name resolution refresh on TRANSIENT_FAILURE or IDLE, similar to ROUND_ROBIN. + if (newState == TRANSIENT_FAILURE || newState == IDLE) { + helper.refreshNameResolution(); + } + maybeUseFallbackBackends(); + maybeUpdatePicker(); + } + } } diff --git a/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java b/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java index 8952ea1d8fb..f394c812b28 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java +++ b/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java @@ -19,14 +19,17 @@ import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; import io.grpc.InternalServiceProviders; +import io.grpc.NameResolver; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; +import io.grpc.Uri; import io.grpc.internal.GrpcUtil; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; import java.util.Collection; import java.util.Collections; +import java.util.List; /** * A provider for {@code io.grpc.grpclb.GrpclbNameResolver}. 
@@ -56,27 +59,47 @@ public static final class Provider extends NameResolverProvider { private static final boolean IS_ANDROID = InternalServiceProviders .isAndroid(SecretGrpclbNameResolverProvider.class.getClassLoader()); + @Override + public NameResolver newNameResolver(Uri targetUri, final NameResolver.Args args) { + if (SCHEME.equals(targetUri.getScheme())) { + List pathSegments = targetUri.getPathSegments(); + Preconditions.checkArgument( + pathSegments.size() == 1, + "expected 1 path segment in target %s but found %s", + targetUri, + pathSegments); + return newNameResolver(targetUri.getAuthority(), pathSegments.get(0), args); + } else { + return null; + } + } + @Override public GrpclbNameResolver newNameResolver(URI targetUri, Args args) { + // TODO(jdcormie): Remove once RFC 3986 migration is complete. if (SCHEME.equals(targetUri.getScheme())) { String targetPath = Preconditions.checkNotNull(targetUri.getPath(), "targetPath"); Preconditions.checkArgument( targetPath.startsWith("/"), "the path component (%s) of the target (%s) must start with '/'", targetPath, targetUri); - String name = targetPath.substring(1); - return new GrpclbNameResolver( - targetUri.getAuthority(), - name, - args, - GrpcUtil.SHARED_CHANNEL_EXECUTOR, - Stopwatch.createUnstarted(), - IS_ANDROID); + return newNameResolver(targetUri.getAuthority(), targetPath.substring(1), args); } else { return null; } } + private GrpclbNameResolver newNameResolver( + String authority, String domainNameToResolve, final NameResolver.Args args) { + return new GrpclbNameResolver( + authority, + domainNameToResolve, + args, + GrpcUtil.SHARED_CHANNEL_EXECUTOR, + Stopwatch.createUnstarted(), + IS_ANDROID); + } + @Override public String getDefaultScheme() { return SCHEME; diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java index e489129676a..ef31b318cb5 100644 --- 
a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java @@ -72,6 +72,7 @@ import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.grpclb.GrpclbState.BackendEntry; +import io.grpc.grpclb.GrpclbState.ChildLbPickerEntry; import io.grpc.grpclb.GrpclbState.DropEntry; import io.grpc.grpclb.GrpclbState.ErrorEntry; import io.grpc.grpclb.GrpclbState.IdleSubchannelEntry; @@ -779,7 +780,9 @@ public void receiveNoBackendAndBalancerAddress() { verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); RoundRobinPicker picker = (RoundRobinPicker) pickerCaptor.getValue(); assertThat(picker.dropList).isEmpty(); - Status error = Iterables.getOnlyElement(picker.pickList).picked(new Metadata()).getStatus(); + PickSubchannelArgs args = mock(PickSubchannelArgs.class); + when(args.getHeaders()).thenReturn(new Metadata()); + Status error = Iterables.getOnlyElement(picker.pickList).picked(args).getStatus(); assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(error.getDescription()).isEqualTo("No backend or balancer addresses found"); } @@ -1915,6 +1918,7 @@ public void grpclbWorking_pickFirstMode() throws Exception { lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends1)); + // With delegation, the child pick_first creates the subchannel inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture()); CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue(); assertThat(createSubchannelArgs.getAddresses()) @@ -1922,42 +1926,41 @@ public void grpclbWorking_pickFirstMode() throws Exception { new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002"))); - // Initially IDLE - inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + // 
Child pick_first eagerly connects, so we start in CONNECTING + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue(); - - // Only one subchannel is created + // Only one subchannel is created by the child LB assertThat(mockSubchannels).hasSize(1); Subchannel subchannel = mockSubchannels.poll(); assertThat(picker0.dropList).containsExactly(null, null); - assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel, syncContext)); + assertThat(picker0.pickList).hasSize(1); + assertThat(picker0.pickList.get(0)).isInstanceOf(ChildLbPickerEntry.class); - // PICK_FIRST doesn't eagerly connect - verify(subchannel, never()).requestConnection(); - - // CONNECTING - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker1.dropList).containsExactly(null, null); - assertThat(picker1.pickList).containsExactly(BUFFER_ENTRY); + // Child pick_first eagerly calls requestConnection() + verify(subchannel).requestConnection(); // TRANSIENT_FAILURE Status error = Status.UNAVAILABLE.withDescription("Simulated connection error"); deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); - inOrder.verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker2.dropList).containsExactly(null, null); - assertThat(picker2.pickList).containsExactly(new ErrorEntry(error)); + // The child LB will notify our helper, which updates grpclb state + inOrder.verify(helper, atLeast(1)) + .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue(); + 
assertThat(picker1.dropList).containsExactly(null, null); + ChildLbPickerEntry failureEntry = (ChildLbPickerEntry) picker1.pickList.get(0); + PickResult failureResult = + failureEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)); + assertThat(failureResult.getStatus()).isEqualTo(error); // READY deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); - RoundRobinPicker picker3 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker3.dropList).containsExactly(null, null); - assertThat(picker3.pickList).containsExactly( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); - + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(READY), pickerCaptor.capture()); + RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(picker2.dropList).containsExactly(null, null); + ChildLbPickerEntry readyEntry = (ChildLbPickerEntry) picker2.pickList.get(0); + PickResult readyResult = + readyEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)); + assertThat(readyResult.getSubchannel()).isEqualTo(subchannel); // New server list with drops List backends2 = Arrays.asList( @@ -1968,37 +1971,40 @@ public void grpclbWorking_pickFirstMode() throws Exception { .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); lbResponseObserver.onNext(buildLbResponse(backends2)); - // new addresses will be updated to the existing subchannel - // createSubchannel() has ever been called only once + // Verify child LB is updated with new addresses, NOT recreated + inOrder.verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper, times(1)).createSubchannel(any(CreateSubchannelArgs.class)); assertThat(mockSubchannels).isEmpty(); + + // The child LB policy internally calls updateAddresses on the subchannel 
verify(subchannel).updateAddresses( eq(Arrays.asList( new EquivalentAddressGroup(backends2.get(0).addr, eagAttrsWithToken("token0001")), new EquivalentAddressGroup(backends2.get(2).addr, eagAttrsWithToken("token0004"))))); - inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); - RoundRobinPicker picker4 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker4.dropList).containsExactly( + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(READY), pickerCaptor.capture()); + RoundRobinPicker picker3 = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(picker3.dropList).containsExactly( null, new DropEntry(getLoadRecorder(), "token0003"), null); - assertThat(picker4.pickList).containsExactly( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); + ChildLbPickerEntry updatedEntry = (ChildLbPickerEntry) picker3.pickList.get(0); + PickResult updatedResult = + updatedEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)); + assertThat(updatedResult.getSubchannel()).isEqualTo(subchannel); - // Subchannel goes IDLE, but PICK_FIRST will not try to reconnect + // Subchannel goes IDLE, grpclb state should follow deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); - RoundRobinPicker picker5 = (RoundRobinPicker) pickerCaptor.getValue(); - verify(subchannel, never()).requestConnection(); + RoundRobinPicker picker4 = (RoundRobinPicker) pickerCaptor.getValue(); - // ... until it's selected + // No new connection request should have happened yet (beyond the first eager one) + verify(subchannel, times(1)).requestConnection(); PickSubchannelArgs args = mock(PickSubchannelArgs.class); - PickResult pick = picker5.pickSubchannel(args); - assertThat(pick).isSameInstanceAs(PickResult.withNoResult()); - verify(subchannel).requestConnection(); - - // ... 
or requested by application - balancer.requestConnection(); + PickResult pick = picker4.pickSubchannel(args); + // Child pick_first picker returns withNoResult() when IDLE and requests connection + assertThat(pick.getSubchannel()).isNull(); verify(subchannel, times(2)).requestConnection(); + balancer.requestConnection(); + verify(subchannel, times(3)).requestConnection(); // PICK_FIRST doesn't use subchannelPool verify(subchannelPool, never()) @@ -2036,6 +2042,7 @@ public void grpclbWorking_pickFirstMode_lbSendsEmptyAddress() throws Exception { lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends1)); + // The child pick_first creates the first subchannel inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture()); CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue(); assertThat(createSubchannelArgs.getAddresses()) @@ -2043,56 +2050,43 @@ public void grpclbWorking_pickFirstMode_lbSendsEmptyAddress() throws Exception { new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002"))); - // Initially IDLE - inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + // Child pick_first eagerly connects, so initial state is CONNECTING + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue(); - - // Only one subchannel is created + // Verify subchannel creation by child LB assertThat(mockSubchannels).hasSize(1); Subchannel subchannel = mockSubchannels.poll(); assertThat(picker0.dropList).containsExactly(null, null); - assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel, syncContext)); - - // PICK_FIRST doesn't eagerly connect - verify(subchannel, never()).requestConnection(); + assertThat(picker0.pickList).hasSize(1); + 
assertThat(picker0.pickList.get(0)).isInstanceOf(ChildLbPickerEntry.class); - // CONNECTING - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker1.dropList).containsExactly(null, null); - assertThat(picker1.pickList).containsExactly(BUFFER_ENTRY); - - // TRANSIENT_FAILURE - Status error = Status.UNAVAILABLE.withDescription("Simulated connection error"); - deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); - inOrder.verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker2.dropList).containsExactly(null, null); - assertThat(picker2.pickList).containsExactly(new ErrorEntry(error)); + // Child pick_first eagerly calls requestConnection() + verify(subchannel).requestConnection(); // READY deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); - RoundRobinPicker picker3 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker3.dropList).containsExactly(null, null); - assertThat(picker3.pickList).containsExactly( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); - + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(READY), pickerCaptor.capture()); + RoundRobinPicker pickerReady = (RoundRobinPicker) pickerCaptor.getValue(); + // Verify the subchannel in the delegated picker + ChildLbPickerEntry readyEntry = (ChildLbPickerEntry) pickerReady.pickList.get(0); + assertThat( + readyEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel()) + .isEqualTo(subchannel); inOrder.verify(helper, never()) 
.updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); - // Empty addresses from LB + // Empty addresses from LB - child LB is shutdown lbResponseObserver.onNext(buildLbResponse(Collections.emptyList())); - // new addresses will be updated to the existing subchannel + // Child LB is shutdown (which shuts down its subchannel) // createSubchannel() has ever been called only once inOrder.verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class)); assertThat(mockSubchannels).isEmpty(); verify(subchannel).shutdown(); // RPC error status includes message of no backends provided by balancer - inOrder.verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + inOrder.verify(helper, atLeast(1)) + .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); RoundRobinPicker errorPicker = (RoundRobinPicker) pickerCaptor.getValue(); assertThat(errorPicker.pickList) .containsExactly(new ErrorEntry(GrpclbState.NO_AVAILABLE_BACKENDS_STATUS)); @@ -2109,18 +2103,22 @@ public void grpclbWorking_pickFirstMode_lbSendsEmptyAddress() throws Exception { .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); lbResponseObserver.onNext(buildLbResponse(backends2)); - // new addresses will be updated to the existing subchannel - inOrder.verify(helper, times(1)).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); - subchannel = mockSubchannels.poll(); + // A NEW child LB and NEW subchannel are created upon recovery + inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); + assertThat(mockSubchannels).hasSize(1); + Subchannel subchannel2 = mockSubchannels.poll(); + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); // Subchannel became READY - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); - 
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); - RoundRobinPicker picker4 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker4.pickList).containsExactly( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); + deliverSubchannelState(subchannel2, ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(READY), pickerCaptor.capture()); + RoundRobinPicker pickerFinal = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(pickerFinal.dropList).containsExactly( + null, new DropEntry(getLoadRecorder(), "token0003"), null); + ChildLbPickerEntry finalEntry = (ChildLbPickerEntry) pickerFinal.pickList.get(0); + assertThat( + finalEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel()) + .isEqualTo(subchannel2); } @Test @@ -2179,7 +2177,7 @@ private void pickFirstModeFallback(long timeout) throws Exception { // Fallback timer expires with no response fakeClock.forwardTime(timeout, TimeUnit.MILLISECONDS); - // Entering fallback mode + // Entering fallback mode - child LB is created for fallback backends inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture()); CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue(); assertThat(createSubchannelArgs.getAddresses()) @@ -2188,23 +2186,24 @@ private void pickFirstModeFallback(long timeout) throws Exception { assertThat(mockSubchannels).hasSize(1); Subchannel subchannel = mockSubchannels.poll(); - // Initially IDLE - inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + // child pick_first eagerly connects, so initial state is CONNECTING + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue(); + 
assertThat(picker0.pickList.get(0)).isInstanceOf(ChildLbPickerEntry.class); - // READY + // Initial eager connection request + verify(subchannel).requestConnection(); + // READY transition in fallback deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(READY), pickerCaptor.capture()); RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue(); assertThat(picker1.dropList).containsExactly(null, null); - assertThat(picker1.pickList).containsExactly( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(null))); + ChildLbPickerEntry readyEntry = (ChildLbPickerEntry) picker1.pickList.get(0); + assertThat( + readyEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel()) + .isEqualTo(subchannel); - assertThat(picker0.dropList).containsExactly(null, null); - assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel, syncContext)); - - - // Finally, an LB response, which brings us out of fallback + // Finally, an LB response arrives, which brings us out of fallback List backends1 = Arrays.asList( new ServerEntry("127.0.0.1", 2000, "token0001"), new ServerEntry("127.0.0.1", 2010, "token0002")); @@ -2213,20 +2212,42 @@ private void pickFirstModeFallback(long timeout) throws Exception { lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends1)); - // new addresses will be updated to the existing subchannel - // createSubchannel() has ever been called only once + // subchannel should be updated, NOT recreated inOrder.verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class)); assertThat(mockSubchannels).isEmpty(); + // The child LB internally calls updateAddresses on the existing subchannel verify(subchannel).updateAddresses( eq(Arrays.asList( new 
EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002"))))); - inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(READY), pickerCaptor.capture()); RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue(); assertThat(picker2.dropList).containsExactly(null, null); - assertThat(picker2.pickList).containsExactly( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); + + // Verify subchannel is still the same via delegated picker + ChildLbPickerEntry updatedEntry = (ChildLbPickerEntry) picker2.pickList.get(0); + assertThat( + updatedEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)) + .getSubchannel()) + .isEqualTo(subchannel); + + // Subchannel goes IDLE, grpclb follows + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); + inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + RoundRobinPicker pickerIdle = (RoundRobinPicker) pickerCaptor.getValue(); + + // Verify connection is NOT eagerly requested again yet (still only the 1st request from start) + verify(subchannel, times(1)).requestConnection(); + + // Picking while IDLE triggers a new connection request + PickSubchannelArgs args = mock(PickSubchannelArgs.class); + PickResult pick = pickerIdle.pickSubchannel(args); + assertThat(pick.getSubchannel()).isNull(); // BUFFERing while IDLE + verify(subchannel, times(2)).requestConnection(); + + balancer.requestConnection(); + verify(subchannel, times(3)).requestConnection(); // PICK_FIRST doesn't use subchannelPool verify(subchannelPool, never()) @@ -2260,6 +2281,8 @@ public void switchMode() throws Exception { List backends1 = Arrays.asList( new ServerEntry("127.0.0.1", 2000, "token0001"), new ServerEntry("127.0.0.1", 2010, "token0002")); + + // RR Mode: Ensure no 
updates before initial response inOrder.verify(helper, never()) .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); lbResponseObserver.onNext(buildInitialResponse()); @@ -2284,7 +2307,6 @@ public void switchMode() throws Exception { Collections.emptyList(), grpclbBalancerList, GrpclbConfig.create(Mode.PICK_FIRST)); - // GrpclbState will be shutdown, and a new one will be created assertThat(oobChannel.isShutdown()).isTrue(); verify(subchannelPool) @@ -2303,13 +2325,13 @@ public void switchMode() throws Exception { InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build()) .build())); - // Simulate receiving LB response + // Simulate receiving LB response for PICK_FIRST inOrder.verify(helper, never()) .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends1)); - // PICK_FIRST Subchannel + // PICK_FIRST Subchannel: child LB creates it inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture()); CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue(); assertThat(createSubchannelArgs.getAddresses()) @@ -2317,7 +2339,9 @@ public void switchMode() throws Exception { new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002"))); - inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + // Child pick_first eagerly connects, so initial state is CONNECTING (not IDLE) + inOrder.verify(helper, atLeast(1)) + .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); } private static Attributes eagAttrsWithToken(String token) { @@ -2344,7 +2368,7 @@ public void switchMode_nullLbPolicy() throws Exception { InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build()) .build())); - // Simulate receiving LB response 
+ // Simulate receiving LB response (Initial default mode: ROUND_ROBIN) List backends1 = Arrays.asList( new ServerEntry("127.0.0.1", 2000, "token0001"), new ServerEntry("127.0.0.1", 2010, "token0002")); @@ -2391,13 +2415,13 @@ public void switchMode_nullLbPolicy() throws Exception { InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build()) .build())); - // Simulate receiving LB response + // Simulate receiving LB response for PICK_FIRST inOrder.verify(helper, never()) .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends1)); - // PICK_FIRST Subchannel + // PICK_FIRST Subchannel: with delegation, child LB creates the subchannel inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture()); CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue(); assertThat(createSubchannelArgs.getAddresses()) @@ -2405,7 +2429,9 @@ public void switchMode_nullLbPolicy() throws Exception { new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002"))); - inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + // Child pick_first eagerly connects, so state is CONNECTING (not IDLE) + inOrder.verify(helper, atLeast(1)) + .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); } @Test diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbNameResolverTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbNameResolverTest.java index c195a78e6f4..a90556a01b0 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbNameResolverTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbNameResolverTest.java @@ -20,7 +20,6 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; -import 
static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -96,7 +95,6 @@ public void close(Executor instance) {} } @Captor private ArgumentCaptor resultCaptor; - @Captor private ArgumentCaptor errorCaptor; @Mock private ServiceConfigParser serviceConfigParser; @Mock private NameResolver.Listener2 mockListener; @@ -154,7 +152,7 @@ public List resolveSrv(String host) throws Exception { verify(mockListener).onResult2(resultCaptor.capture()); ResolutionResult result = resultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); assertThat(result.getAttributes()).isEqualTo(Attributes.EMPTY); assertThat(result.getServiceConfig()).isNull(); } @@ -196,7 +194,7 @@ public ConfigOrError answer(InvocationOnMock invocation) { ResolutionResult result = resultCaptor.getValue(); InetSocketAddress resolvedBackendAddr = (InetSocketAddress) Iterables.getOnlyElement( - Iterables.getOnlyElement(result.getAddresses()).getAddresses()); + Iterables.getOnlyElement(result.getAddressesOrError().getValue()).getAddresses()); assertThat(resolvedBackendAddr.getAddress()).isEqualTo(backendAddr); EquivalentAddressGroup resolvedBalancerAddr = Iterables.getOnlyElement(result.getAttributes().get(GrpclbConstants.ATTR_LB_ADDRS)); @@ -227,7 +225,7 @@ public void resolve_nullResourceResolver() throws Exception { assertThat(fakeClock.runDueTasks()).isEqualTo(1); verify(mockListener).onResult2(resultCaptor.capture()); ResolutionResult result = resultCaptor.getValue(); - assertThat(result.getAddresses()) + assertThat(result.getAddressesOrError().getValue()) .containsExactly( new EquivalentAddressGroup(new InetSocketAddress(backendAddr, DEFAULT_PORT))); assertThat(result.getAttributes()).isEqualTo(Attributes.EMPTY); @@ -245,8 +243,8 @@ public void resolve_nullResourceResolver_addressFailure() throws Exception { resolver.start(mockListener); 
assertThat(fakeClock.runDueTasks()).isEqualTo(1); - verify(mockListener).onError(errorCaptor.capture()); - Status errorStatus = errorCaptor.getValue(); + verify(mockListener).onResult2(resultCaptor.capture()); + Status errorStatus = resultCaptor.getValue().getAddressesOrError().getStatus(); assertThat(errorStatus.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(errorStatus.getCause()).hasMessageThat().contains("no addr"); } @@ -274,7 +272,7 @@ public void resolve_addressFailure_stillLookUpBalancersAndServiceConfig() throws assertThat(fakeClock.runDueTasks()).isEqualTo(1); verify(mockListener).onResult2(resultCaptor.capture()); ResolutionResult result = resultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); EquivalentAddressGroup resolvedBalancerAddr = Iterables.getOnlyElement(result.getAttributes().get(GrpclbConstants.ATTR_LB_ADDRS)); assertThat(resolvedBalancerAddr.getAttributes().get(GrpclbConstants.ATTR_LB_ADDR_AUTHORITY)) @@ -311,7 +309,7 @@ public void resolveAll_balancerLookupFails_stillLookUpServiceConfig() throws Exc InetSocketAddress resolvedBackendAddr = (InetSocketAddress) Iterables.getOnlyElement( - Iterables.getOnlyElement(result.getAddresses()).getAddresses()); + Iterables.getOnlyElement(result.getAddressesOrError().getValue()).getAddresses()); assertThat(resolvedBackendAddr.getAddress()).isEqualTo(backendAddr); assertThat(result.getAttributes().get(GrpclbConstants.ATTR_LB_ADDRS)).isNull(); verify(mockAddressResolver).resolveAddress(hostName); @@ -320,7 +318,7 @@ public void resolveAll_balancerLookupFails_stillLookUpServiceConfig() throws Exc } @Test - public void resolve_addressAndBalancersLookupFail_neverLookupServiceConfig() throws Exception { + public void resolve_addressAndBalancersLookupFail_stillLookupServiceConfig() throws Exception { AddressResolver mockAddressResolver = mock(AddressResolver.class); when(mockAddressResolver.resolveAddress(anyString())) 
.thenThrow(new UnknownHostException("I really tried")); @@ -335,11 +333,11 @@ public void resolve_addressAndBalancersLookupFail_neverLookupServiceConfig() thr resolver.start(mockListener); assertThat(fakeClock.runDueTasks()).isEqualTo(1); - verify(mockListener).onError(errorCaptor.capture()); - Status errorStatus = errorCaptor.getValue(); + verify(mockListener).onResult2(resultCaptor.capture()); + Status errorStatus = resultCaptor.getValue().getAddressesOrError().getStatus(); assertThat(errorStatus.getCode()).isEqualTo(Code.UNAVAILABLE); verify(mockAddressResolver).resolveAddress(hostName); - verify(mockResourceResolver, never()).resolveTxt("_grpc_config." + hostName); + verify(mockResourceResolver).resolveTxt("_grpc_config." + hostName); verify(mockResourceResolver).resolveSrv("_grpclb._tcp." + hostName); } } diff --git a/grpclb/src/test/java/io/grpc/grpclb/SecretGrpclbNameResolverProviderTest.java b/grpclb/src/test/java/io/grpc/grpclb/SecretGrpclbNameResolverProviderTest.java index 24b1c781f58..e9ed92a54d0 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/SecretGrpclbNameResolverProviderTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/SecretGrpclbNameResolverProviderTest.java @@ -17,6 +17,8 @@ package io.grpc.grpclb; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.TruthJUnit.assume; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; @@ -24,15 +26,19 @@ import io.grpc.NameResolver; import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.SynchronizationContext; +import io.grpc.Uri; import io.grpc.internal.DnsNameResolverProvider; import io.grpc.internal.GrpcUtil; import java.net.URI; +import java.util.Arrays; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; 
/** Unit tests for {@link SecretGrpclbNameResolverProvider}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class SecretGrpclbNameResolverProviderTest { private final SynchronizationContext syncContext = new SynchronizationContext( @@ -53,6 +59,13 @@ public void uncaughtException(Thread t, Throwable e) { private SecretGrpclbNameResolverProvider.Provider provider = new SecretGrpclbNameResolverProvider.Provider(); + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + @Test public void isAvailable() { assertThat(provider.isAvailable()).isTrue(); @@ -66,43 +79,65 @@ public void priority_shouldBeHigherThanDefaultDnsNameResolver() { } @Test - public void newNameResolver() { - assertThat(provider.newNameResolver(URI.create("dns:///localhost:443"), args)) + public void newNameResolverReturnsCorrectType() { + assertThat(newNameResolver("dns:///localhost:443", args)) .isInstanceOf(GrpclbNameResolver.class); - assertThat(provider.newNameResolver(URI.create("notdns:///localhost:443"), args)).isNull(); + assertThat(newNameResolver("notdns:///localhost:443", args)).isNull(); } @Test public void invalidDnsName() throws Exception { - testInvalidUri(new URI("dns", null, "/[invalid]", null)); + testInvalidUri("dns:/%5Binvalid%5D"); } @Test public void validIpv6() throws Exception { - testValidUri(new URI("dns", null, "/[::1]", null)); + testValidUri("dns:/%5B::1%5D"); } @Test public void validDnsNameWithoutPort() throws Exception { - testValidUri(new URI("dns", null, "/foo.googleapis.com", null)); + testValidUri("dns:/foo.googleapis.com"); } @Test public void validDnsNameWithPort() throws Exception { - testValidUri(new URI("dns", null, "/foo.googleapis.com:456", null)); + testValidUri("dns:/foo.googleapis.com:456"); + } + + @Test + public void newNameResolver_rejectsExtraPathSegments() { + 
assume().that(enableRfc3986UrisParam).isTrue(); + IllegalArgumentException iae = + assertThrows( + IllegalArgumentException.class, + () -> newNameResolver("dns:///localhost:443/extras", args)); + assertThat(iae).hasMessageThat().contains("expected 1 path segment in target"); } - private void testInvalidUri(URI uri) { + @Test + public void newNameResolver_toleratesExtraPathSegments() { + assume().that(enableRfc3986UrisParam).isFalse(); + newNameResolver("dns:///localhost:443/extras", args); + } + + private void testInvalidUri(String uri) { try { - provider.newNameResolver(uri, args); + newNameResolver(uri, args); fail("Should have failed"); } catch (IllegalArgumentException e) { // expected } } - private void testValidUri(URI uri) { - GrpclbNameResolver resolver = provider.newNameResolver(uri, args); + private void testValidUri(String uri) { + NameResolver resolver = newNameResolver(uri, args); assertThat(resolver).isNotNull(); } + + private NameResolver newNameResolver(String uriString, NameResolver.Args args) { + return enableRfc3986UrisParam + ? 
provider.newNameResolver(Uri.create(uriString), args) + : provider.newNameResolver(URI.create(uriString), args); + } } diff --git a/inprocess/BUILD.bazel b/inprocess/BUILD.bazel index bef38612713..e9c5001c5ec 100644 --- a/inprocess/BUILD.bazel +++ b/inprocess/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( diff --git a/inprocess/build.gradle b/inprocess/build.gradle index edc97883b50..075968ccb9a 100644 --- a/inprocess/build.gradle +++ b/inprocess/build.gradle @@ -22,8 +22,16 @@ dependencies { testFixtures(project(':grpc-core')) testImplementation libraries.guava.testlib - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("javadoc").configure { diff --git a/inprocess/src/main/java/io/grpc/inprocess/AnonymousInProcessSocketAddress.java b/inprocess/src/main/java/io/grpc/inprocess/AnonymousInProcessSocketAddress.java index 5f6486e335d..c458857d70b 100644 --- a/inprocess/src/main/java/io/grpc/inprocess/AnonymousInProcessSocketAddress.java +++ b/inprocess/src/main/java/io/grpc/inprocess/AnonymousInProcessSocketAddress.java @@ -18,11 +18,13 @@ import static com.google.common.base.Preconditions.checkState; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.ExperimentalApi; import java.io.IOException; +import java.io.NotSerializableException; +import java.io.ObjectOutputStream; import java.net.SocketAddress; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Custom SocketAddress class for {@link InProcessTransport}, for @@ -34,8 +36,13 @@ public final class AnonymousInProcessSocketAddress extends SocketAddress { @Nullable @GuardedBy("this") + @SuppressWarnings("serial") private InProcessServer server; 
+ private void writeObject(ObjectOutputStream out) throws IOException { + throw new NotSerializableException("AnonymousInProcessSocketAddress is not serializable"); + } + /** Creates a new AnonymousInProcessSocketAddress. */ public AnonymousInProcessSocketAddress() { } diff --git a/inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java b/inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java index c000b66b2a2..9b33b3d3618 100644 --- a/inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java +++ b/inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.inprocess.InProcessTransport.isEnabledSupportTracingMessageSizes; import com.google.errorprone.annotations.DoNotCall; import io.grpc.ChannelCredentials; @@ -94,6 +95,7 @@ public static InProcessChannelBuilder forAddress(String name, int port) { private ScheduledExecutorService scheduledExecutorService; private int maxInboundMetadataSize = Integer.MAX_VALUE; private boolean transportIncludeStatusCause = false; + private long assumedMessageSize = -1; private InProcessChannelBuilder(@Nullable SocketAddress directAddress, @Nullable String target) { @@ -117,10 +119,9 @@ public ClientTransportFactory buildClientTransportFactory() { managedChannelImplBuilder.setStatsRecordStartedRpcs(false); managedChannelImplBuilder.setStatsRecordFinishedRpcs(false); managedChannelImplBuilder.setStatsRecordRetryMetrics(false); - - // By default, In-process transport should not be retriable as that leaks memory. 
Since - // there is no wire, bytes aren't calculated so buffer limit isn't respected - managedChannelImplBuilder.disableRetry(); + if (!isEnabledSupportTracingMessageSizes) { + managedChannelImplBuilder.disableRetry(); + } } @Internal @@ -225,9 +226,24 @@ public InProcessChannelBuilder propagateCauseWithStatus(boolean enable) { return this; } + /** + * Assumes RPC messages are the specified size. This avoids serializing + * messages for metrics and retry memory tracking. This can dramatically + * improve performance when accurate message sizes are not needed and if + * nothing else needs the serialized message. + * @param assumedMessageSize length of InProcess transport's messageSize. + * @return this + * @throws IllegalArgumentException if assumedMessageSize is negative. + */ + public InProcessChannelBuilder assumedMessageSize(long assumedMessageSize) { + checkArgument(assumedMessageSize >= 0, "assumedMessageSize must be >= 0"); + this.assumedMessageSize = assumedMessageSize; + return this; + } + ClientTransportFactory buildTransportFactory() { - return new InProcessClientTransportFactory( - scheduledExecutorService, maxInboundMetadataSize, transportIncludeStatusCause); + return new InProcessClientTransportFactory(scheduledExecutorService, + maxInboundMetadataSize, transportIncludeStatusCause, assumedMessageSize); } void setStatsEnabled(boolean value) { @@ -243,15 +259,17 @@ static final class InProcessClientTransportFactory implements ClientTransportFac private final int maxInboundMetadataSize; private boolean closed; private final boolean includeCauseWithStatus; + private long assumedMessageSize; private InProcessClientTransportFactory( @Nullable ScheduledExecutorService scheduledExecutorService, - int maxInboundMetadataSize, boolean includeCauseWithStatus) { + int maxInboundMetadataSize, boolean includeCauseWithStatus, long assumedMessageSize) { useSharedTimer = scheduledExecutorService == null; timerService = useSharedTimer ? 
SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE) : scheduledExecutorService; this.maxInboundMetadataSize = maxInboundMetadataSize; this.includeCauseWithStatus = includeCauseWithStatus; + this.assumedMessageSize = assumedMessageSize; } @Override @@ -263,7 +281,7 @@ public ConnectionClientTransport newClientTransport( // TODO(carl-mastrangelo): Pass channelLogger in. return new InProcessTransport( addr, maxInboundMetadataSize, options.getAuthority(), options.getUserAgent(), - options.getEagAttributes(), includeCauseWithStatus); + options.getEagAttributes(), includeCauseWithStatus, assumedMessageSize); } @Override diff --git a/inprocess/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java b/inprocess/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java index 190f67603c3..b2004426aae 100644 --- a/inprocess/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java +++ b/inprocess/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java @@ -24,6 +24,7 @@ import io.grpc.ExperimentalApi; import io.grpc.ForwardingServerBuilder; import io.grpc.Internal; +import io.grpc.MetricRecorder; import io.grpc.ServerBuilder; import io.grpc.ServerStreamTracer; import io.grpc.internal.FixedObjectPool; @@ -120,7 +121,8 @@ private InProcessServerBuilder(SocketAddress listenAddress) { final class InProcessClientTransportServersBuilder implements ClientTransportServersBuilder { @Override public InternalServer buildClientTransportServers( - List streamTracerFactories) { + List streamTracerFactories, + MetricRecorder metricRecorder) { return buildTransportServers(streamTracerFactories); } } diff --git a/inprocess/src/main/java/io/grpc/inprocess/InProcessTransport.java b/inprocess/src/main/java/io/grpc/inprocess/InProcessTransport.java index ae8ad143d2c..a92f10fd5c5 100644 --- a/inprocess/src/main/java/io/grpc/inprocess/InProcessTransport.java +++ b/inprocess/src/main/java/io/grpc/inprocess/InProcessTransport.java @@ -18,12 +18,13 @@ import static 
com.google.common.base.Preconditions.checkNotNull; import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; -import static java.lang.Math.max; import com.google.common.base.MoreObjects; -import com.google.common.base.Optional; +import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.CheckReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; @@ -35,6 +36,7 @@ import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.InternalMetadata; +import io.grpc.KnownLength; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.SecurityLevel; @@ -52,13 +54,14 @@ import io.grpc.internal.ManagedClientTransport; import io.grpc.internal.NoopClientStream; import io.grpc.internal.ObjectPool; -import io.grpc.internal.ServerListener; import io.grpc.internal.ServerStream; import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.StreamListener; +import java.io.ByteArrayInputStream; import java.io.InputStream; import java.net.SocketAddress; import java.util.ArrayDeque; @@ -73,21 +76,20 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; @ThreadSafe final class InProcessTransport implements ServerTransport, ConnectionClientTransport { private static final Logger log = Logger.getLogger(InProcessTransport.class.getName()); + static boolean isEnabledSupportTracingMessageSizes = + 
GrpcUtil.getFlag("GRPC_EXPERIMENTAL_SUPPORT_TRACING_MESSAGE_SIZES", false); private final InternalLogId logId; private final SocketAddress address; private final int clientMaxInboundMetadataSize; private final String authority; private final String userAgent; - private final Optional optionalServerListener; private int serverMaxInboundMetadataSize; private final boolean includeCauseWithStatus; private ObjectPool serverSchedulerPool; @@ -95,6 +97,8 @@ final class InProcessTransport implements ServerTransport, ConnectionClientTrans private ServerTransportListener serverTransportListener; private Attributes serverStreamAttributes; private ManagedClientTransport.Listener clientTransportListener; + // The size is assumed from the sender's side. + private final long assumedMessageSize; @GuardedBy("this") private boolean shutdown; @GuardedBy("this") @@ -134,9 +138,9 @@ protected void handleNotInUse() { } }; - private InProcessTransport(SocketAddress address, int maxInboundMetadataSize, String authority, + public InProcessTransport(SocketAddress address, int maxInboundMetadataSize, String authority, String userAgent, Attributes eagAttrs, - Optional optionalServerListener, boolean includeCauseWithStatus) { + boolean includeCauseWithStatus, long assumedMessageSize) { this.address = address; this.clientMaxInboundMetadataSize = maxInboundMetadataSize; this.authority = authority; @@ -148,47 +152,23 @@ private InProcessTransport(SocketAddress address, int maxInboundMetadataSize, St .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, address) .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, address) .build(); - this.optionalServerListener = optionalServerListener; logId = InternalLogId.allocate(getClass(), address.toString()); this.includeCauseWithStatus = includeCauseWithStatus; - } - - public InProcessTransport( - SocketAddress address, int maxInboundMetadataSize, String authority, String userAgent, - Attributes eagAttrs, boolean includeCauseWithStatus) { - this(address, maxInboundMetadataSize, 
authority, userAgent, eagAttrs, - Optional.absent(), includeCauseWithStatus); - } - - InProcessTransport( - String name, int maxInboundMetadataSize, String authority, String userAgent, - Attributes eagAttrs, ObjectPool serverSchedulerPool, - List serverStreamTracerFactories, - ServerListener serverListener, boolean includeCauseWithStatus) { - this(new InProcessSocketAddress(name), maxInboundMetadataSize, authority, userAgent, eagAttrs, - Optional.of(serverListener), includeCauseWithStatus); - this.serverMaxInboundMetadataSize = maxInboundMetadataSize; - this.serverSchedulerPool = serverSchedulerPool; - this.serverStreamTracerFactories = serverStreamTracerFactories; + this.assumedMessageSize = assumedMessageSize; } @CheckReturnValue @Override public synchronized Runnable start(ManagedClientTransport.Listener listener) { this.clientTransportListener = listener; - if (optionalServerListener.isPresent()) { + InProcessServer server = InProcessServer.findServer(address); + if (server != null) { + serverMaxInboundMetadataSize = server.getMaxInboundMetadataSize(); + serverSchedulerPool = server.getScheduledExecutorServicePool(); serverScheduler = serverSchedulerPool.getObject(); - serverTransportListener = optionalServerListener.get().transportCreated(this); - } else { - InProcessServer server = InProcessServer.findServer(address); - if (server != null) { - serverMaxInboundMetadataSize = server.getMaxInboundMetadataSize(); - serverSchedulerPool = server.getScheduledExecutorServicePool(); - serverScheduler = serverSchedulerPool.getObject(); - serverStreamTracerFactories = server.getStreamTracerFactories(); - // Must be semi-initialized; past this point, can begin receiving requests - serverTransportListener = server.register(this); - } + serverStreamTracerFactories = server.getStreamTracerFactories(); + // Must be semi-initialized; past this point, can begin receiving requests + serverTransportListener = server.register(this); } if (serverTransportListener == null) { 
shutdownStatus = Status.UNAVAILABLE.withDescription("Could not find server: " + address); @@ -266,7 +246,7 @@ public synchronized void ping(final PingCallback callback, Executor executor) { executor.execute(new Runnable() { @Override public void run() { - callback.onFailure(shutdownStatus.asRuntimeException()); + callback.onFailure(shutdownStatus); } }); } else { @@ -348,7 +328,7 @@ private synchronized void notifyShutdown(Status s) { return; } shutdown = true; - clientTransportListener.transportShutdown(s); + clientTransportListener.transportShutdown(s, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); } private synchronized void notifyTerminated() { @@ -507,6 +487,25 @@ private void clientCancelled(Status status) { @Override public void writeMessage(InputStream message) { + long messageLength = 0; + if (isEnabledSupportTracingMessageSizes) { + try { + if (assumedMessageSize != -1) { + messageLength = assumedMessageSize; + } else if (message instanceof KnownLength || message instanceof ByteArrayInputStream) { + messageLength = message.available(); + } else { + InputStream oldMessage = message; + byte[] payload = ByteStreams.toByteArray(message); + messageLength = payload.length; + message = new ByteArrayInputStream(payload); + oldMessage.close(); + } + } catch (Exception e) { + throw new RuntimeException("Error processing the message length", e); + } + } + synchronized (this) { if (closed) { return; @@ -515,6 +514,13 @@ public void writeMessage(InputStream message) { statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1); clientStream.statsTraceCtx.inboundMessage(outboundSeqNo); clientStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1); + if (isEnabledSupportTracingMessageSizes) { + statsTraceCtx.outboundUncompressedSize(messageLength); + statsTraceCtx.outboundWireSize(messageLength); + // messageLength should be same at receiver's end as no actual wire is involved. 
+ clientStream.statsTraceCtx.inboundUncompressedSize(messageLength); + clientStream.statsTraceCtx.inboundWireSize(messageLength); + } outboundSeqNo++; StreamListener.MessageProducer producer = new SingleMessageProducer(message); if (clientRequested > 0) { @@ -524,7 +530,6 @@ public void writeMessage(InputStream message) { clientReceiveQueue.add(producer); } } - syncContext.drain(); } @@ -601,7 +606,7 @@ public void close(Status status, Metadata trailers) { notifyClientClose(status, trailers); } - /** clientStream.serverClosed() must be called before this method */ + /** clientStream.serverClosed() must be called before this method. */ private void notifyClientClose(Status status, Metadata trailers) { Status clientStatus = cleanStatus(status, includeCauseWithStatus); synchronized (this) { @@ -778,6 +783,24 @@ private void serverClosed(Status serverListenerStatus, Status serverTracerStatus @Override public void writeMessage(InputStream message) { + long messageLength = 0; + if (isEnabledSupportTracingMessageSizes) { + try { + if (assumedMessageSize != -1) { + messageLength = assumedMessageSize; + } else if (message instanceof KnownLength || message instanceof ByteArrayInputStream) { + messageLength = message.available(); + } else { + InputStream oldMessage = message; + byte[] payload = ByteStreams.toByteArray(message); + messageLength = payload.length; + message = new ByteArrayInputStream(payload); + oldMessage.close(); + } + } catch (Exception e) { + throw new RuntimeException("Error processing the message length", e); + } + } synchronized (this) { if (closed) { return; @@ -786,6 +809,13 @@ public void writeMessage(InputStream message) { statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1); serverStream.statsTraceCtx.inboundMessage(outboundSeqNo); serverStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1); + if (isEnabledSupportTracingMessageSizes) { + statsTraceCtx.outboundUncompressedSize(messageLength); + 
statsTraceCtx.outboundWireSize(messageLength); + // messageLength should be same at receiver's end as no actual wire is involved. + serverStream.statsTraceCtx.inboundUncompressedSize(messageLength); + serverStream.statsTraceCtx.inboundWireSize(messageLength); + } outboundSeqNo++; StreamListener.MessageProducer producer = new SingleMessageProducer(message); if (serverRequested > 0) { @@ -909,8 +939,7 @@ public void setMaxOutboundMessageSize(int maxSize) {} @Override public void setDeadline(Deadline deadline) { headers.discardAll(TIMEOUT_KEY); - long effectiveTimeout = max(0, deadline.timeRemaining(TimeUnit.NANOSECONDS)); - headers.put(TIMEOUT_KEY, effectiveTimeout); + headers.put(TIMEOUT_KEY, deadline.timeRemaining(TimeUnit.NANOSECONDS)); } @Override diff --git a/inprocess/src/main/java/io/grpc/inprocess/InternalInProcess.java b/inprocess/src/main/java/io/grpc/inprocess/InternalInProcess.java deleted file mode 100644 index 680373533c8..00000000000 --- a/inprocess/src/main/java/io/grpc/inprocess/InternalInProcess.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.grpc.inprocess; - -import io.grpc.Attributes; -import io.grpc.Internal; -import io.grpc.ServerStreamTracer; -import io.grpc.internal.ConnectionClientTransport; -import io.grpc.internal.ObjectPool; -import io.grpc.internal.ServerListener; -import java.util.List; -import java.util.concurrent.ScheduledExecutorService; - -/** - * Internal {@link InProcessTransport} accessor. - * - *

This is intended for use by io.grpc.internal, and the specifically - * supported transport packages. - */ -@Internal -public final class InternalInProcess { - - private InternalInProcess() {} - - /** - * Creates a new InProcessTransport. - * - *

When started, the transport will be registered with the given - * {@link ServerListener}. - */ - @Internal - public static ConnectionClientTransport createInProcessTransport( - String name, - int maxInboundMetadataSize, - String authority, - String userAgent, - Attributes eagAttrs, - ObjectPool serverSchedulerPool, - List serverStreamTracerFactories, - ServerListener serverListener, - boolean includeCauseWithStatus) { - return new InProcessTransport( - name, - maxInboundMetadataSize, - authority, - userAgent, - eagAttrs, - serverSchedulerPool, - serverStreamTracerFactories, - serverListener, - includeCauseWithStatus); - } -} diff --git a/inprocess/src/test/java/io/grpc/inprocess/AnonymousInProcessTransportTest.java b/inprocess/src/test/java/io/grpc/inprocess/AnonymousInProcessTransportTest.java index a78a604eac3..7bf884c9ff9 100644 --- a/inprocess/src/test/java/io/grpc/inprocess/AnonymousInProcessTransportTest.java +++ b/inprocess/src/test/java/io/grpc/inprocess/AnonymousInProcessTransportTest.java @@ -52,6 +52,6 @@ protected InternalServer newServer( protected ManagedClientTransport newClientTransport(InternalServer server) { return new InProcessTransport( address, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, - testAuthority(server), USER_AGENT, eagAttrs(), false); + testAuthority(server), USER_AGENT, eagAttrs(), false, -1); } } diff --git a/inprocess/src/test/java/io/grpc/inprocess/InProcessTransportTest.java b/inprocess/src/test/java/io/grpc/inprocess/InProcessTransportTest.java index 420a9c4a8e7..d2220e05114 100644 --- a/inprocess/src/test/java/io/grpc/inprocess/InProcessTransportTest.java +++ b/inprocess/src/test/java/io/grpc/inprocess/InProcessTransportTest.java @@ -17,6 +17,7 @@ package io.grpc.inprocess; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import io.grpc.CallOptions; @@ -34,15 +35,25 @@ import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; import 
io.grpc.internal.AbstractTransportTest; +import io.grpc.internal.ClientStream; +import io.grpc.internal.ClientStreamListenerBase; import io.grpc.internal.GrpcUtil; import io.grpc.internal.InternalServer; import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.MockServerTransportListener; +import io.grpc.internal.MockServerTransportListener.StreamCreation; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerStreamListenerBase; +import io.grpc.internal.testing.TestStreamTracer; import io.grpc.stub.ClientCalls; import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.TestMethodDescriptors; +import java.io.InputStream; +import java.util.Arrays; import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import org.junit.Assert; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -55,10 +66,18 @@ public class InProcessTransportTest extends AbstractTransportTest { private static final String TRANSPORT_NAME = "perfect-for-testing"; private static final String AUTHORITY = "a-testing-authority"; protected static final String USER_AGENT = "a-testing-user-agent"; + private static final int TIMEOUT_MS = 5000; + private static final long TEST_MESSAGE_LENGTH = 100; @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); + @Override + protected InternalServer newServer( + int port, List streamTracerFactories) { + return newServer(streamTracerFactories); + } + @Override protected InternalServer newServer( List streamTracerFactories) { @@ -68,12 +87,6 @@ protected InternalServer newServer( return new InProcessServer(builder, streamTracerFactories); } - @Override - protected InternalServer newServer( - int port, List streamTracerFactories) { - return newServer(streamTracerFactories); - } - @Override protected String testAuthority(InternalServer server) { return AUTHORITY; @@ -83,14 +96,13 @@ protected String testAuthority(InternalServer server) { 
protected ManagedClientTransport newClientTransport(InternalServer server) { return new InProcessTransport( new InProcessSocketAddress(TRANSPORT_NAME), GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, - testAuthority(server), USER_AGENT, eagAttrs(), false); + testAuthority(server), USER_AGENT, eagAttrs(), false, -1); } - @Override - protected boolean sizesReported() { - // TODO(zhangkun83): InProcessTransport doesn't record metrics for now - // (https://github.com/grpc/grpc-java/issues/2284) - return false; + private ManagedClientTransport newClientTransportWithAssumedMessageSize(InternalServer server) { + return new InProcessTransport( + new InProcessSocketAddress(TRANSPORT_NAME), GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, + testAuthority(server), USER_AGENT, eagAttrs(), false, TEST_MESSAGE_LENGTH); } @Test @@ -170,11 +182,67 @@ public Listener startCall(ServerCall call, Metadata headers) { .build(); ClientCall call = channel.newCall(nonMatchMethod, CallOptions.DEFAULT); try { - ClientCalls.futureUnaryCall(call, null).get(5, TimeUnit.SECONDS); + ClientCalls.futureUnaryCall(call, null).get(TIMEOUT_MS, TimeUnit.MILLISECONDS); fail("Call should fail."); } catch (ExecutionException ex) { StatusRuntimeException s = (StatusRuntimeException)ex.getCause(); assertEquals(Code.UNIMPLEMENTED, s.getStatus().getCode()); } } + + @Test + public void basicStreamInProcess() throws Exception { + InProcessServerBuilder builder = InProcessServerBuilder + .forName(TRANSPORT_NAME) + .maxInboundMetadataSize(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); + server = new InProcessServer(builder, Arrays.asList(serverStreamTracerFactory)); + server.start(serverListener); + client = newClientTransportWithAssumedMessageSize(server); + startTransport(client, mockClientTransportListener); + MockServerTransportListener serverTransportListener + = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); + serverTransport = serverTransportListener.transport; + // Set up client stream + ClientStream 
clientStream = client.newStream( + methodDescriptor, new Metadata(), CallOptions.DEFAULT, tracers); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); + StreamCreation serverStreamCreation + = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); + ServerStream serverStream = serverStreamCreation.stream; + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; + serverStream.request(1); + assertTrue(clientStream.isReady()); + // Send message from client to server + clientStream.writeMessage(methodDescriptor.streamRequest("Hello from client")); + clientStream.flush(); + // Verify server received the message and check its size + InputStream message = + serverStreamListener.messageQueue.poll(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertEquals("Hello from client", methodDescriptor.parseRequest(message)); + message.close(); + clientStream.halfClose(); + assertAssumedMessageSize(clientStreamTracer1, serverStreamTracer1); + + clientStream.request(1); + assertTrue(serverStream.isReady()); + serverStream.writeMessage(methodDescriptor.streamResponse("Hi from server")); + serverStream.flush(); + message = clientStreamListener.messageQueue.poll(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertEquals("Hi from server", methodDescriptor.parseResponse(message)); + assertAssumedMessageSize(serverStreamTracer1, clientStreamTracer1); + message.close(); + Status status = Status.OK.withDescription("That was normal"); + serverStream.close(status, new Metadata()); + } + + private void assertAssumedMessageSize( + TestStreamTracer streamTracerSender, TestStreamTracer streamTracerReceiver) { + if (isEnabledSupportTracingMessageSizes()) { + Assert.assertEquals(TEST_MESSAGE_LENGTH, streamTracerSender.getOutboundWireSize()); + Assert.assertEquals(TEST_MESSAGE_LENGTH, streamTracerSender.getOutboundUncompressedSize()); + Assert.assertEquals(TEST_MESSAGE_LENGTH, 
streamTracerReceiver.getInboundWireSize()); + Assert.assertEquals(TEST_MESSAGE_LENGTH, streamTracerReceiver.getInboundUncompressedSize()); + } + } } diff --git a/inprocess/src/test/java/io/grpc/inprocess/StandaloneInProcessTransportTest.java b/inprocess/src/test/java/io/grpc/inprocess/StandaloneInProcessTransportTest.java deleted file mode 100644 index b1d80d53b8b..00000000000 --- a/inprocess/src/test/java/io/grpc/inprocess/StandaloneInProcessTransportTest.java +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.grpc.inprocess; - -import io.grpc.InternalChannelz.SocketStats; -import io.grpc.InternalInstrumented; -import io.grpc.ServerStreamTracer; -import io.grpc.internal.AbstractTransportTest; -import io.grpc.internal.GrpcUtil; -import io.grpc.internal.InternalServer; -import io.grpc.internal.ManagedClientTransport; -import io.grpc.internal.ObjectPool; -import io.grpc.internal.ServerListener; -import io.grpc.internal.ServerTransport; -import io.grpc.internal.ServerTransportListener; -import io.grpc.internal.SharedResourcePool; -import java.io.IOException; -import java.net.SocketAddress; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; -import org.junit.Ignore; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Unit tests for {@link InProcessTransport} when used with a separate {@link InternalServer}. */ -@RunWith(JUnit4.class) -public final class StandaloneInProcessTransportTest extends AbstractTransportTest { - private static final String TRANSPORT_NAME = "perfect-for-testing"; - private static final String AUTHORITY = "a-testing-authority"; - private static final String USER_AGENT = "a-testing-user-agent"; - - private final ObjectPool schedulerPool = - SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); - - private TestServer currentServer; - - @Override - protected InternalServer newServer( - List streamTracerFactories) { - return new TestServer(streamTracerFactories); - } - - @Override - protected InternalServer newServer( - int port, List streamTracerFactories) { - return newServer(streamTracerFactories); - } - - @Override - protected String testAuthority(InternalServer server) { - return AUTHORITY; - } - - @Override - protected ManagedClientTransport newClientTransport(InternalServer server) { - TestServer testServer = (TestServer) server; - return InternalInProcess.createInProcessTransport( - 
TRANSPORT_NAME, - GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, - testAuthority(server), - USER_AGENT, - eagAttrs(), - schedulerPool, - testServer.streamTracerFactories, - testServer.serverListener, - false); - } - - @Override - protected boolean sizesReported() { - // TODO(zhangkun83): InProcessTransport doesn't record metrics for now - // (https://github.com/grpc/grpc-java/issues/2284) - return false; - } - - @Test - @Ignore - @Override - public void socketStats() throws Exception { - // test does not apply to in-process - } - - /** An internalserver just for this test. */ - private final class TestServer implements InternalServer { - - final List streamTracerFactories; - ServerListener serverListener; - - TestServer(List streamTracerFactories) { - this.streamTracerFactories = streamTracerFactories; - } - - @Override - public void start(ServerListener serverListener) throws IOException { - if (currentServer != null) { - throw new IOException("Server already present"); - } - currentServer = this; - this.serverListener = new ServerListenerWrapper(serverListener); - } - - @Override - public void shutdown() { - currentServer = null; - serverListener.serverShutdown(); - } - - @Override - public SocketAddress getListenSocketAddress() { - return new SocketAddress() {}; - } - - @Override - public List getListenSocketAddresses() { - return Collections.singletonList(getListenSocketAddress()); - } - - @Override - @Nullable - public InternalInstrumented getListenSocketStats() { - return null; - } - - @Override - @Nullable - public List> getListenSocketStatsList() { - return null; - } - } - - /** Wraps the server listener to ensure we don't accept new transports after shutdown. 
*/ - private static final class ServerListenerWrapper implements ServerListener { - private final ServerListener delegateListener; - private boolean shutdown; - - ServerListenerWrapper(ServerListener delegateListener) { - this.delegateListener = delegateListener; - } - - @Override - public ServerTransportListener transportCreated(ServerTransport transport) { - if (shutdown) { - return null; - } - return delegateListener.transportCreated(transport); - } - - @Override - public void serverShutdown() { - shutdown = true; - delegateListener.serverShutdown(); - } - } -} diff --git a/interop-testing/build.gradle b/interop-testing/build.gradle index a19efb00155..5160759460c 100644 --- a/interop-testing/build.gradle +++ b/interop-testing/build.gradle @@ -13,6 +13,7 @@ dependencies { implementation project(path: ':grpc-alts', configuration: 'shadow'), project(':grpc-auth'), project(':grpc-census'), + project(':grpc-opentelemetry'), project(':grpc-gcp-csm-observability'), project(':grpc-netty'), project(':grpc-okhttp'), @@ -30,7 +31,6 @@ dependencies { project(':grpc-stub'), project(':grpc-protobuf'), libraries.junit - compileOnly libraries.javax.annotation // TODO(sergiitk): replace with com.google.cloud:google-cloud-logging // Used instead of google-cloud-logging because it's failing // due to a circular dependency on grpc. 
@@ -52,8 +52,16 @@ dependencies { libraries.mockito.core, libraries.okhttp - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } configureProtoCompilation() diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java index 2f4dc69c0c6..22c64d12f33 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java @@ -7,9 +7,6 @@ * A service used to obtain stats for verifying LB behavior. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class LoadBalancerStatsServiceGrpc { @@ -94,6 +91,21 @@ public LoadBalancerStatsServiceStub newStub(io.grpc.Channel channel, io.grpc.Cal return LoadBalancerStatsServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static LoadBalancerStatsServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public LoadBalancerStatsServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerStatsServiceBlockingV2Stub(channel, callOptions); + } + }; + return LoadBalancerStatsServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -214,6 
+226,46 @@ public void getClientAccumulatedStats(io.grpc.testing.integration.Messages.LoadB * A service used to obtain stats for verifying LB behavior. * */ + public static final class LoadBalancerStatsServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private LoadBalancerStatsServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected LoadBalancerStatsServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerStatsServiceBlockingV2Stub(channel, callOptions); + } + + /** + *

+     * Gets the backend distribution for RPCs sent by a test client.
+     * 
+ */ + public io.grpc.testing.integration.Messages.LoadBalancerStatsResponse getClientStats(io.grpc.testing.integration.Messages.LoadBalancerStatsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetClientStatsMethod(), getCallOptions(), request); + } + + /** + *
+     * Gets the accumulated stats for RPCs sent by a test client.
+     * 
+ */ + public io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsResponse getClientAccumulatedStats(io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetClientAccumulatedStatsMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service LoadBalancerStatsService. + *
+   * A service used to obtain stats for verifying LB behavior.
+   * 
+ */ public static final class LoadBalancerStatsServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private LoadBalancerStatsServiceBlockingStub( diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java index 1650365bd52..980dee010f1 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/metrics.proto") @io.grpc.stub.annotations.GrpcGenerated public final class MetricsServiceGrpc { @@ -91,6 +88,21 @@ public MetricsServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions c return MetricsServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static MetricsServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public MetricsServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new MetricsServiceBlockingV2Stub(channel, callOptions); + } + }; + return MetricsServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -201,6 +213,46 @@ public void getGauge(io.grpc.testing.integration.Metrics.GaugeRequest request, /** * A stub to allow clients to do synchronous rpc calls to service MetricsService. 
*/ + public static final class MetricsServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private MetricsServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected MetricsServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new MetricsServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Returns the values of all the gauges that are currently being maintained by
+     * the service
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + getAllGauges(io.grpc.testing.integration.Metrics.EmptyMessage request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getGetAllGaugesMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns the value of one gauge
+     * 
+ */ + public io.grpc.testing.integration.Metrics.GaugeResponse getGauge(io.grpc.testing.integration.Metrics.GaugeRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetGaugeMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service MetricsService. + */ public static final class MetricsServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private MetricsServiceBlockingStub( diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java index d1887ee83c4..05d46ce8e95 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java @@ -7,9 +7,6 @@ * A service used to control reconnect server. 
* */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ReconnectServiceGrpc { @@ -94,6 +91,21 @@ public ReconnectServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions return ReconnectServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ReconnectServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ReconnectServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReconnectServiceBlockingV2Stub(channel, callOptions); + } + }; + return ReconnectServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -202,6 +214,40 @@ public void stop(io.grpc.testing.integration.EmptyProtos.Empty request, * A service used to control reconnect server. 
* */ + public static final class ReconnectServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ReconnectServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ReconnectServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReconnectServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty start(io.grpc.testing.integration.Messages.ReconnectParams request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getStartMethod(), getCallOptions(), request); + } + + /** + */ + public io.grpc.testing.integration.Messages.ReconnectInfo stop(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getStopMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ReconnectService. + *
+   * A service used to control reconnect server.
+   * 
+ */ public static final class ReconnectServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ReconnectServiceBlockingStub( diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/TestServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/TestServiceGrpc.java index 08071a3b653..a881c85c150 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/TestServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/TestServiceGrpc.java @@ -8,9 +8,6 @@ * performance with various types of payload. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { @@ -281,6 +278,21 @@ public TestServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions call return TestServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -551,6 +563,125 @@ public void unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty requ * performance with various types of payload. 
* */ + public static final class TestServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * One empty request followed by one empty response.
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty emptyCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getEmptyCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by one response.
+     * 
+ */ + public io.grpc.testing.integration.Messages.SimpleResponse unaryCall(io.grpc.testing.integration.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by one response. Response has cache control
+     * headers set such that a caching HTTP proxy (such as GFE) can
+     * satisfy subsequent requests.
+     * 
+ */ + public io.grpc.testing.integration.Messages.SimpleResponse cacheableUnaryCall(io.grpc.testing.integration.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getCacheableUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by a sequence of responses (streamed download).
+     * The server returns the payload with client desired type and sizes.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingOutputCall(io.grpc.testing.integration.Messages.StreamingOutputCallRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamingOutputCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A sequence of requests followed by one response (streamed upload).
+     * The server returns the aggregated size of client payload as the result.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingInputCall() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getStreamingInputCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests with each request served by the server immediately.
+     * As one request could lead to multiple responses, this interface
+     * demonstrates the idea of full duplexing.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + fullDuplexCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getFullDuplexCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests followed by a sequence of responses.
+     * The server buffers all the client requests and then serves them in order. A
+     * stream of responses are returned to the client when the server starts with
+     * first request.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + halfDuplexCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getHalfDuplexCallMethod(), getCallOptions()); + } + + /** + *
+     * The test server will not implement this method. It will be used
+     * to test the behavior when clients call unimplemented methods.
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnimplementedCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestService. + *
+   * A simple service to test the various types of RPCs and experiment with
+   * performance with various types of payload.
+   * 
+ */ public static final class TestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestServiceBlockingStub( diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java index 9711386185e..fdd8d5650ed 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java @@ -8,9 +8,6 @@ * that case. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class UnimplementedServiceGrpc { @@ -64,6 +61,21 @@ public UnimplementedServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOpt return UnimplementedServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static UnimplementedServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public UnimplementedServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new UnimplementedServiceBlockingV2Stub(channel, callOptions); + } + }; + return UnimplementedServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -167,6 +179,37 @@ public void unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty requ * that case. 
* */ + public static final class UnimplementedServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private UnimplementedServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected UnimplementedServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new UnimplementedServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * A call that no server should implement
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnimplementedCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service UnimplementedService. + *
+   * A simple service NOT implemented at servers so clients can test for
+   * that case.
+   * 
+ */ public static final class UnimplementedServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private UnimplementedServiceBlockingStub( diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java index 164119a29e7..6c019efefea 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java @@ -7,9 +7,6 @@ * A service to dynamically update the configuration of an xDS test client. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class XdsUpdateClientConfigureServiceGrpc { @@ -63,6 +60,21 @@ public XdsUpdateClientConfigureServiceStub newStub(io.grpc.Channel channel, io.g return XdsUpdateClientConfigureServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static XdsUpdateClientConfigureServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public XdsUpdateClientConfigureServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateClientConfigureServiceBlockingV2Stub(channel, callOptions); + } + }; + return XdsUpdateClientConfigureServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -162,6 +174,36 @@ public void configure(io.grpc.testing.integration.Messages.ClientConfigureReques * A service to dynamically update 
the configuration of an xDS test client. * */ + public static final class XdsUpdateClientConfigureServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private XdsUpdateClientConfigureServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected XdsUpdateClientConfigureServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateClientConfigureServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Update the tes client's configuration.
+     * 
+ */ + public io.grpc.testing.integration.Messages.ClientConfigureResponse configure(io.grpc.testing.integration.Messages.ClientConfigureRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getConfigureMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service XdsUpdateClientConfigureService. + *
+   * A service to dynamically update the configuration of an xDS test client.
+   * 
+ */ public static final class XdsUpdateClientConfigureServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private XdsUpdateClientConfigureServiceBlockingStub( diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java index dccd23ccbee..5531033ae5c 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java @@ -7,9 +7,6 @@ * A service to remotely control health status of an xDS test server. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class XdsUpdateHealthServiceGrpc { @@ -94,6 +91,21 @@ public XdsUpdateHealthServiceStub newStub(io.grpc.Channel channel, io.grpc.CallO return XdsUpdateHealthServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static XdsUpdateHealthServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public XdsUpdateHealthServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateHealthServiceBlockingV2Stub(channel, callOptions); + } + }; + return XdsUpdateHealthServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -202,6 +214,40 @@ public void setNotServing(io.grpc.testing.integration.EmptyProtos.Empty request, * A service to remotely control health status of an xDS test server. 
* */ + public static final class XdsUpdateHealthServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private XdsUpdateHealthServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected XdsUpdateHealthServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateHealthServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty setServing(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSetServingMethod(), getCallOptions(), request); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty setNotServing(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSetNotServingMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service XdsUpdateHealthService. + *
+   * A service to remotely control health status of an xDS test server.
+   * 
+ */ public static final class XdsUpdateHealthServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private XdsUpdateHealthServiceBlockingStub( diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index a581c750028..51295281a90 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -119,7 +119,6 @@ import javax.annotation.Nullable; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; -import org.HdrHistogram.Histogram; import org.junit.After; import org.junit.Assert; import org.junit.Assume; @@ -133,7 +132,7 @@ /** * Abstract base class for all GRPC transport tests. * - *

New tests should avoid using Mockito to support running on AppEngine.

+ *

New tests should avoid using Mockito to support running on AppEngine. */ public abstract class AbstractInteropTest { private static Logger logger = Logger.getLogger(AbstractInteropTest.class.getName()); @@ -484,7 +483,7 @@ public void clientCompressedUnary(boolean probe) throws Exception { blockingStub.unaryCall(expectCompressedRequest); fail("expected INVALID_ARGUMENT"); } catch (StatusRuntimeException e) { - assertEquals(Status.INVALID_ARGUMENT.getCode(), e.getStatus().getCode()); + assertCodeEquals(Status.Code.INVALID_ARGUMENT, e.getStatus()); } assertStatsTrace("grpc.testing.TestService/UnaryCall", Status.Code.INVALID_ARGUMENT); } @@ -653,7 +652,7 @@ public void clientCompressedStreaming(boolean probe) throws Exception { responseObserver.awaitCompletion(operationTimeoutMillis(), TimeUnit.MILLISECONDS); Throwable e = responseObserver.getError(); assertNotNull("expected INVALID_ARGUMENT", e); - assertEquals(Status.INVALID_ARGUMENT.getCode(), Status.fromThrowable(e).getCode()); + assertCodeEquals(Status.Code.INVALID_ARGUMENT, Status.fromThrowable(e)); } // Start a new stream @@ -802,8 +801,7 @@ public void cancelAfterBegin() throws Exception { requestObserver.onError(new RuntimeException()); responseObserver.awaitCompletion(); assertEquals(Arrays.asList(), responseObserver.getValues()); - assertEquals(Status.Code.CANCELLED, - Status.fromThrowable(responseObserver.getError()).getCode()); + assertCodeEquals(Status.Code.CANCELLED, Status.fromThrowable(responseObserver.getError())); if (metricsExpected()) { MetricsRecord clientStartRecord = clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); @@ -840,8 +838,7 @@ public void cancelAfterFirstResponse() throws Exception { requestObserver.onError(new RuntimeException()); responseObserver.awaitCompletion(operationTimeoutMillis(), TimeUnit.MILLISECONDS); assertEquals(1, responseObserver.getValues().size()); - assertEquals(Status.Code.CANCELLED, - Status.fromThrowable(responseObserver.getError()).getCode()); + 
assertCodeEquals(Status.Code.CANCELLED, Status.fromThrowable(responseObserver.getError())); assertStatsTrace("grpc.testing.TestService/FullDuplexCall", Status.Code.CANCELLED); } @@ -1108,7 +1105,7 @@ public void deadlineExceeded() throws Exception { stub.streamingOutputCall(request).next(); fail("Expected deadline to be exceeded"); } catch (StatusRuntimeException ex) { - assertEquals(Status.DEADLINE_EXCEEDED.getCode(), ex.getStatus().getCode()); + assertCodeEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus()); String desc = ex.getStatus().getDescription(); assertTrue(desc, // There is a race between client and server-side deadline expiration. @@ -1154,8 +1151,7 @@ public void deadlineExceededServerStreaming() throws Exception { .withDeadlineAfter(30, TimeUnit.MILLISECONDS) .streamingOutputCall(request, recorder); recorder.awaitCompletion(); - assertEquals(Status.DEADLINE_EXCEEDED.getCode(), - Status.fromThrowable(recorder.getError()).getCode()); + assertCodeEquals(Status.Code.DEADLINE_EXCEEDED, Status.fromThrowable(recorder.getError())); if (metricsExpected()) { // Stream may not have been created when deadline is exceeded, thus we don't check tracer // stats. 
@@ -1180,7 +1176,7 @@ public void deadlineInPast() throws Exception { .emptyCall(Empty.getDefaultInstance()); fail("Should have thrown"); } catch (StatusRuntimeException ex) { - assertEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus().getCode()); + assertCodeEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus()); assertThat(ex.getStatus().getDescription()) .startsWith("ClientCall started after CallOptions deadline was exceeded"); } @@ -1213,7 +1209,7 @@ public void deadlineInPast() throws Exception { .emptyCall(Empty.getDefaultInstance()); fail("Should have thrown"); } catch (StatusRuntimeException ex) { - assertEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus().getCode()); + assertCodeEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus()); assertThat(ex.getStatus().getDescription()) .startsWith("ClientCall started after CallOptions deadline was exceeded"); } @@ -1279,8 +1275,7 @@ public void maxInboundSize_tooBig() { stub.streamingOutputCall(request).next(); fail(); } catch (StatusRuntimeException ex) { - Status s = ex.getStatus(); - assertWithMessage(s.toString()).that(s.getCode()).isEqualTo(Status.Code.RESOURCE_EXHAUSTED); + assertCodeEquals(Status.Code.RESOURCE_EXHAUSTED, ex.getStatus()); assertThat(Throwables.getStackTraceAsString(ex)).contains("exceeds maximum"); } } @@ -1335,8 +1330,7 @@ public void maxOutboundSize_tooBig() { stub.streamingOutputCall(request).next(); fail(); } catch (StatusRuntimeException ex) { - Status s = ex.getStatus(); - assertWithMessage(s.toString()).that(s.getCode()).isEqualTo(Status.Code.CANCELLED); + assertCodeEquals(Status.Code.CANCELLED, ex.getStatus()); assertThat(Throwables.getStackTraceAsString(ex)).contains("message too large"); } } @@ -1558,7 +1552,7 @@ public void statusCodeAndMessage() throws Exception { blockingStub.unaryCall(simpleRequest); fail(); } catch (StatusRuntimeException e) { - assertEquals(Status.UNKNOWN.getCode(), e.getStatus().getCode()); + assertCodeEquals(Status.Code.UNKNOWN, e.getStatus()); 
assertEquals(errorMessage, e.getStatus().getDescription()); } assertStatsTrace("grpc.testing.TestService/UnaryCall", Status.Code.UNKNOWN); @@ -1574,7 +1568,7 @@ public void statusCodeAndMessage() throws Exception { .isTrue(); assertThat(responseObserver.getError()).isNotNull(); Status status = Status.fromThrowable(responseObserver.getError()); - assertEquals(Status.UNKNOWN.getCode(), status.getCode()); + assertCodeEquals(Status.Code.UNKNOWN, status); assertEquals(errorMessage, status.getDescription()); assertStatsTrace("grpc.testing.TestService/FullDuplexCall", Status.Code.UNKNOWN); } @@ -1594,7 +1588,7 @@ public void specialStatusMessage() throws Exception { blockingStub.unaryCall(simpleRequest); fail(); } catch (StatusRuntimeException e) { - assertEquals(Status.UNKNOWN.getCode(), e.getStatus().getCode()); + assertCodeEquals(Status.Code.UNKNOWN, e.getStatus()); assertEquals(errorMessage, e.getStatus().getDescription()); } assertStatsTrace("grpc.testing.TestService/UnaryCall", Status.Code.UNKNOWN); @@ -1607,7 +1601,7 @@ public void unimplementedMethod() { blockingStub.unimplementedCall(Empty.getDefaultInstance()); fail(); } catch (StatusRuntimeException e) { - assertEquals(Status.UNIMPLEMENTED.getCode(), e.getStatus().getCode()); + assertCodeEquals(Status.Code.UNIMPLEMENTED, e.getStatus()); } assertClientStatsTrace("grpc.testing.TestService/UnimplementedCall", @@ -1623,7 +1617,7 @@ public void unimplementedService() { stub.unimplementedCall(Empty.getDefaultInstance()); fail(); } catch (StatusRuntimeException e) { - assertEquals(Status.UNIMPLEMENTED.getCode(), e.getStatus().getCode()); + assertCodeEquals(Status.Code.UNIMPLEMENTED, e.getStatus()); } assertStatsTrace("grpc.testing.UnimplementedService/UnimplementedCall", @@ -1631,7 +1625,6 @@ public void unimplementedService() { } /** Start a fullDuplexCall which the server will not respond, and verify the deadline expires. 
*/ - @SuppressWarnings("MissingFail") @Test public void timeoutOnSleepingServer() throws Exception { TestServiceGrpc.TestServiceStub stub = @@ -1641,20 +1634,15 @@ public void timeoutOnSleepingServer() throws Exception { StreamObserver requestObserver = stub.fullDuplexCall(responseObserver); - StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder() + requestObserver.onNext(StreamingOutputCallRequest.newBuilder() .setPayload(Payload.newBuilder() .setBody(ByteString.copyFrom(new byte[27182]))) - .build(); - try { - requestObserver.onNext(request); - } catch (IllegalStateException expected) { - // This can happen if the stream has already been terminated due to deadline exceeded. - } + .build()); assertTrue(responseObserver.awaitCompletion(operationTimeoutMillis(), TimeUnit.MILLISECONDS)); assertEquals(0, responseObserver.getValues().size()); - assertEquals(Status.DEADLINE_EXCEEDED.getCode(), - Status.fromThrowable(responseObserver.getError()).getCode()); + assertCodeEquals( + Status.Code.DEADLINE_EXCEEDED, Status.fromThrowable(responseObserver.getError())); if (metricsExpected()) { // CensusStreamTracerModule record final status in the interceptor, thus is guaranteed to be @@ -1680,148 +1668,6 @@ public void getServerAddressAndLocalAddressFromClient() { assertNotNull(obtainLocalClientAddr()); } - private static class SoakIterationResult { - public SoakIterationResult(long latencyMs, Status status) { - this.latencyMs = latencyMs; - this.status = status; - } - - public long getLatencyMs() { - return latencyMs; - } - - public Status getStatus() { - return status; - } - - private long latencyMs = -1; - private Status status = Status.OK; - } - - private SoakIterationResult performOneSoakIteration( - TestServiceGrpc.TestServiceBlockingStub soakStub, int soakRequestSize, int soakResponseSize) - throws Exception { - long startNs = System.nanoTime(); - Status status = Status.OK; - try { - final SimpleRequest request = - SimpleRequest.newBuilder() - 
.setResponseSize(soakResponseSize) - .setPayload( - Payload.newBuilder().setBody(ByteString.copyFrom(new byte[soakRequestSize]))) - .build(); - final SimpleResponse goldenResponse = - SimpleResponse.newBuilder() - .setPayload( - Payload.newBuilder().setBody(ByteString.copyFrom(new byte[soakResponseSize]))) - .build(); - assertResponse(goldenResponse, soakStub.unaryCall(request)); - } catch (StatusRuntimeException e) { - status = e.getStatus(); - } - long elapsedNs = System.nanoTime() - startNs; - return new SoakIterationResult(TimeUnit.NANOSECONDS.toMillis(elapsedNs), status); - } - - /** - * Runs large unary RPCs in a loop with configurable failure thresholds - * and channel creation behavior. - */ - public void performSoakTest( - String serverUri, - boolean resetChannelPerIteration, - int soakIterations, - int maxFailures, - int maxAcceptablePerIterationLatencyMs, - int minTimeMsBetweenRpcs, - int overallTimeoutSeconds, - int soakRequestSize, - int soakResponseSize) - throws Exception { - int iterationsDone = 0; - int totalFailures = 0; - Histogram latencies = new Histogram(4 /* number of significant value digits */); - long startNs = System.nanoTime(); - ManagedChannel soakChannel = createChannel(); - TestServiceGrpc.TestServiceBlockingStub soakStub = TestServiceGrpc - .newBlockingStub(soakChannel) - .withInterceptors(recordClientCallInterceptor(clientCallCapture)); - for (int i = 0; i < soakIterations; i++) { - if (System.nanoTime() - startNs >= TimeUnit.SECONDS.toNanos(overallTimeoutSeconds)) { - break; - } - long earliestNextStartNs = System.nanoTime() - + TimeUnit.MILLISECONDS.toNanos(minTimeMsBetweenRpcs); - if (resetChannelPerIteration) { - soakChannel.shutdownNow(); - soakChannel.awaitTermination(10, TimeUnit.SECONDS); - soakChannel = createChannel(); - soakStub = TestServiceGrpc - .newBlockingStub(soakChannel) - .withInterceptors(recordClientCallInterceptor(clientCallCapture)); - } - SoakIterationResult result = - performOneSoakIteration(soakStub, 
soakRequestSize, soakResponseSize); - SocketAddress peer = clientCallCapture - .get().getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); - StringBuilder logStr = new StringBuilder( - String.format( - Locale.US, - "soak iteration: %d elapsed_ms: %d peer: %s server_uri: %s", - i, result.getLatencyMs(), peer != null ? peer.toString() : "null", serverUri)); - if (!result.getStatus().equals(Status.OK)) { - totalFailures++; - logStr.append(String.format(" failed: %s", result.getStatus())); - } else if (result.getLatencyMs() > maxAcceptablePerIterationLatencyMs) { - totalFailures++; - logStr.append( - " exceeds max acceptable latency: " + maxAcceptablePerIterationLatencyMs); - } else { - logStr.append(" succeeded"); - } - System.err.println(logStr.toString()); - iterationsDone++; - latencies.recordValue(result.getLatencyMs()); - long remainingNs = earliestNextStartNs - System.nanoTime(); - if (remainingNs > 0) { - TimeUnit.NANOSECONDS.sleep(remainingNs); - } - } - soakChannel.shutdownNow(); - soakChannel.awaitTermination(10, TimeUnit.SECONDS); - System.err.println( - String.format( - Locale.US, - "(server_uri: %s) soak test ran: %d / %d iterations. total failures: %d. 
" - + "p50: %d ms, p90: %d ms, p100: %d ms", - serverUri, - iterationsDone, - soakIterations, - totalFailures, - latencies.getValueAtPercentile(50), - latencies.getValueAtPercentile(90), - latencies.getValueAtPercentile(100))); - // check if we timed out - String timeoutErrorMessage = - String.format( - Locale.US, - "(server_uri: %s) soak test consumed all %d seconds of time and quit early, " - + "only having ran %d out of desired %d iterations.", - serverUri, - overallTimeoutSeconds, - iterationsDone, - soakIterations); - assertEquals(timeoutErrorMessage, iterationsDone, soakIterations); - // check if we had too many failures - String tooManyFailuresErrorMessage = - String.format( - Locale.US, - "(server_uri: %s) soak test total failures: %d exceeds max failures " - + "threshold: %d.", - serverUri, totalFailures, maxFailures); - assertTrue(tooManyFailuresErrorMessage, totalFailures <= maxFailures); - } - private static void assertSuccess(StreamRecorder recorder) { if (recorder.getError() != null) { throw new AssertionError(recorder.getError()); @@ -2178,6 +2024,10 @@ private void assertPayload(Payload expected, Payload actual) { } } + private static void assertCodeEquals(Status.Code expected, Status actual) { + assertWithMessage("Unexpected status: %s", actual).that(actual.getCode()).isEqualTo(expected); + } + /** * Captures the request attributes. Useful for testing ServerCalls. 
* {@link ServerCall#getAttributes()} diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProvider.java b/interop-testing/src/main/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProvider.java index 83c416765ec..f1410142bff 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProvider.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProvider.java @@ -110,12 +110,20 @@ protected LoadBalancer delegate() { return delegateLb; } + @Deprecated @Override public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { helper.setRpcBehavior( ((RpcBehaviorConfig) resolvedAddresses.getLoadBalancingPolicyConfig()).rpcBehavior); delegateLb.handleResolvedAddresses(resolvedAddresses); } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + helper.setRpcBehavior( + ((RpcBehaviorConfig) resolvedAddresses.getLoadBalancingPolicyConfig()).rpcBehavior); + return delegateLb.acceptResolvedAddresses(resolvedAddresses); + } } /** diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/SoakClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/SoakClient.java new file mode 100644 index 00000000000..e119c826f09 --- /dev/null +++ b/interop-testing/src/main/java/io/grpc/testing/integration/SoakClient.java @@ -0,0 +1,300 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.common.base.Function; +import com.google.protobuf.ByteString; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.Grpc; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.testing.integration.Messages.Payload; +import io.grpc.testing.integration.Messages.SimpleRequest; +import io.grpc.testing.integration.Messages.SimpleResponse; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Logger; +import org.HdrHistogram.Histogram; + +/** + * Shared implementation for rpc_soak and channel_soak. Unlike the tests in AbstractInteropTest, + * these "test cases" are only intended to be run from the command line. They don't fit the regular + * test patterns of AbstractInteropTest. 
+ * https://github.com/grpc/grpc/blob/master/doc/interop-test-descriptions.md#rpc_soak + */ +final class SoakClient { + private static final Logger logger = Logger.getLogger(SoakClient.class.getName()); + + private static class SoakIterationResult { + public SoakIterationResult(long latencyMs, Status status) { + this.latencyMs = latencyMs; + this.status = status; + } + + public long getLatencyMs() { + return latencyMs; + } + + public Status getStatus() { + return status; + } + + private long latencyMs = -1; + private Status status = Status.OK; + } + + private static class ThreadResults { + private int threadFailures = 0; + private int iterationsDone = 0; + private Histogram latencies = new Histogram(4); + + public int getThreadFailures() { + return threadFailures; + } + + public int getIterationsDone() { + return iterationsDone; + } + + public Histogram getLatencies() { + return latencies; + } + } + + private static SoakIterationResult performOneSoakIteration( + TestServiceGrpc.TestServiceBlockingStub soakStub, int soakRequestSize, int soakResponseSize) + throws InterruptedException { + long startNs = System.nanoTime(); + Status status = Status.OK; + try { + final SimpleRequest request = + SimpleRequest.newBuilder() + .setResponseSize(soakResponseSize) + .setPayload( + Payload.newBuilder().setBody(ByteString.copyFrom(new byte[soakRequestSize]))) + .build(); + final SimpleResponse goldenResponse = + SimpleResponse.newBuilder() + .setPayload( + Payload.newBuilder().setBody(ByteString.copyFrom(new byte[soakResponseSize]))) + .build(); + assertResponse(goldenResponse, soakStub.unaryCall(request)); + } catch (StatusRuntimeException e) { + status = e.getStatus(); + } + long elapsedNs = System.nanoTime() - startNs; + return new SoakIterationResult(TimeUnit.NANOSECONDS.toMillis(elapsedNs), status); + } + + /** + * Runs large unary RPCs in a loop with configurable failure thresholds + * and channel creation behavior. 
+ */ + public static void performSoakTest( + String serverUri, + int soakIterations, + int maxFailures, + int maxAcceptablePerIterationLatencyMs, + int minTimeMsBetweenRpcs, + int overallTimeoutSeconds, + int soakRequestSize, + int soakResponseSize, + int numThreads, + ManagedChannel sharedChannel, + Function maybeCreateChannel) + throws InterruptedException { + if (soakIterations % numThreads != 0) { + throw new IllegalArgumentException("soakIterations must be evenly divisible by numThreads."); + } + long startNs = System.nanoTime(); + Thread[] threads = new Thread[numThreads]; + int soakIterationsPerThread = soakIterations / numThreads; + List threadResultsList = new ArrayList<>(numThreads); + for (int i = 0; i < numThreads; i++) { + threadResultsList.add(new ThreadResults()); + } + for (int threadInd = 0; threadInd < numThreads; threadInd++) { + final int currentThreadInd = threadInd; + threads[threadInd] = new Thread(() -> { + try { + executeSoakTestInThread( + soakIterationsPerThread, + startNs, + minTimeMsBetweenRpcs, + soakRequestSize, + soakResponseSize, + maxAcceptablePerIterationLatencyMs, + overallTimeoutSeconds, + serverUri, + threadResultsList.get(currentThreadInd), + sharedChannel, + maybeCreateChannel); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Thread interrupted: " + e.getMessage(), e); + } + }); + threads[threadInd].start(); + } + for (Thread thread : threads) { + thread.join(); + } + + int totalFailures = 0; + int iterationsDone = 0; + Histogram latencies = new Histogram(4); + for (ThreadResults threadResult :threadResultsList) { + totalFailures += threadResult.getThreadFailures(); + iterationsDone += threadResult.getIterationsDone(); + latencies.add(threadResult.getLatencies()); + } + logger.info( + String.format( + Locale.US, + "(server_uri: %s) soak test ran: %d / %d iterations. total failures: %d. 
" + + "p50: %d ms, p90: %d ms, p100: %d ms", + serverUri, + iterationsDone, + soakIterations, + totalFailures, + latencies.getValueAtPercentile(50), + latencies.getValueAtPercentile(90), + latencies.getValueAtPercentile(100))); + // check if we timed out + String timeoutErrorMessage = + String.format( + Locale.US, + "(server_uri: %s) soak test consumed all %d seconds of time and quit early, " + + "only having ran %d out of desired %d iterations.", + serverUri, + overallTimeoutSeconds, + iterationsDone, + soakIterations); + assertEquals(timeoutErrorMessage, iterationsDone, soakIterations); + // check if we had too many failures + String tooManyFailuresErrorMessage = + String.format( + Locale.US, + "(server_uri: %s) soak test total failures: %d exceeds max failures " + + "threshold: %d.", + serverUri, totalFailures, maxFailures); + assertTrue(tooManyFailuresErrorMessage, totalFailures <= maxFailures); + sharedChannel.shutdownNow(); + sharedChannel.awaitTermination(10, TimeUnit.SECONDS); + } + + private static void executeSoakTestInThread( + int soakIterationsPerThread, + long startNs, + int minTimeMsBetweenRpcs, + int soakRequestSize, + int soakResponseSize, + int maxAcceptablePerIterationLatencyMs, + int overallTimeoutSeconds, + String serverUri, + ThreadResults threadResults, + ManagedChannel sharedChannel, + Function maybeCreateChannel) throws InterruptedException { + ManagedChannel currentChannel = sharedChannel; + for (int i = 0; i < soakIterationsPerThread; i++) { + if (System.nanoTime() - startNs >= TimeUnit.SECONDS.toNanos(overallTimeoutSeconds)) { + break; + } + long earliestNextStartNs = System.nanoTime() + + TimeUnit.MILLISECONDS.toNanos(minTimeMsBetweenRpcs); + // recordClientCallInterceptor takes an AtomicReference. 
+ AtomicReference> soakThreadClientCallCapture = new AtomicReference<>(); + currentChannel = maybeCreateChannel.apply(currentChannel); + TestServiceGrpc.TestServiceBlockingStub currentStub = TestServiceGrpc + .newBlockingStub(currentChannel) + .withInterceptors(recordClientCallInterceptor(soakThreadClientCallCapture)); + SoakIterationResult result = performOneSoakIteration(currentStub, + soakRequestSize, soakResponseSize); + SocketAddress peer = soakThreadClientCallCapture + .get().getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + StringBuilder logStr = new StringBuilder( + String.format( + Locale.US, + "thread id: %d soak iteration: %d elapsed_ms: %d peer: %s server_uri: %s", + Thread.currentThread().getId(), + i, result.getLatencyMs(), peer != null ? peer.toString() : "null", serverUri)); + if (!result.getStatus().equals(Status.OK)) { + threadResults.threadFailures++; + logStr.append(String.format(" failed: %s", result.getStatus())); + logger.warning(logStr.toString()); + } else if (result.getLatencyMs() > maxAcceptablePerIterationLatencyMs) { + threadResults.threadFailures++; + logStr.append( + " exceeds max acceptable latency: " + maxAcceptablePerIterationLatencyMs); + logger.warning(logStr.toString()); + } else { + logStr.append(" succeeded"); + logger.info(logStr.toString()); + } + threadResults.iterationsDone++; + threadResults.getLatencies().recordValue(result.getLatencyMs()); + long remainingNs = earliestNextStartNs - System.nanoTime(); + if (remainingNs > 0) { + TimeUnit.NANOSECONDS.sleep(remainingNs); + } + } + } + + private static void assertResponse(SimpleResponse expected, SimpleResponse actual) { + assertPayload(expected.getPayload(), actual.getPayload()); + assertEquals(expected.getUsername(), actual.getUsername()); + assertEquals(expected.getOauthScope(), actual.getOauthScope()); + } + + private static void assertPayload(Payload expected, Payload actual) { + // Compare non deprecated fields in Payload, to make this test forward compatible. 
+ if (expected == null || actual == null) { + assertEquals(expected, actual); + } else { + assertEquals(expected.getBody(), actual.getBody()); + } + } + + /** + * Captures the ClientCall. Useful for testing {@link ClientCall#getAttributes()} + */ + private static ClientInterceptor recordClientCallInterceptor( + final AtomicReference> clientCallCapture) { + return new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + ClientCall clientCall = next.newCall(method,callOptions); + clientCallCapture.set(clientCall); + return clientCall; + } + }; + } + +} diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java index e6829be11cb..125d876b705 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java @@ -134,6 +134,7 @@ public static void main(String[] args) throws Exception { soakIterations * soakPerIterationMaxAcceptableLatencyMs / 1000; private int soakRequestSize = 271828; private int soakResponseSize = 314159; + private int numThreads = 1; private String additionalMetadata = ""; private static LoadBalancerProvider customBackendMetricsLoadBalancerProvider; @@ -214,6 +215,8 @@ void parseArgs(String[] args) throws Exception { soakRequestSize = Integer.parseInt(value); } else if ("soak_response_size".equals(key)) { soakResponseSize = Integer.parseInt(value); + } else if ("soak_num_threads".equals(key)) { + numThreads = Integer.parseInt(value); } else if ("additional_metadata".equals(key)) { additionalMetadata = value; } else { @@ -290,6 +293,9 @@ void parseArgs(String[] args) throws Exception { + "\n --soak_response_size " + "\n The response size in a soak RPC. 
Default " + c.soakResponseSize + + "\n --soak_num_threads The number of threads for concurrent execution of the " + + "\n soak tests (rpc_soak or channel_soak). Default " + + c.numThreads + "\n --additional_metadata " + "\n Additional metadata to send in each request, as a " + "\n semicolon-separated list of key:value pairs. Default " @@ -517,32 +523,35 @@ private void runTest(TestCases testCase) throws Exception { } case RPC_SOAK: { - tester.performSoakTest( + SoakClient.performSoakTest( serverHost, - false /* resetChannelPerIteration */, soakIterations, soakMaxFailures, soakPerIterationMaxAcceptableLatencyMs, soakMinTimeMsBetweenRpcs, soakOverallTimeoutSeconds, soakRequestSize, - soakResponseSize); + soakResponseSize, + numThreads, + tester.createChannelBuilder().build(), + (currentChannel) -> currentChannel); break; } case CHANNEL_SOAK: { - tester.performSoakTest( + SoakClient.performSoakTest( serverHost, - true /* resetChannelPerIteration */, soakIterations, soakMaxFailures, soakPerIterationMaxAcceptableLatencyMs, soakMinTimeMsBetweenRpcs, soakOverallTimeoutSeconds, soakRequestSize, - soakResponseSize); + soakResponseSize, + numThreads, + tester.createChannelBuilder().build(), + (currentChannel) -> tester.createNewChannel(currentChannel)); break; - } case ORCA_PER_RPC: { @@ -704,6 +713,16 @@ protected ManagedChannelBuilder createChannelBuilder() { return okBuilder.intercept(createCensusStatsClientInterceptor()); } + ManagedChannel createNewChannel(ManagedChannel currentChannel) { + currentChannel.shutdownNow(); + try { + currentChannel.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while creating a new channel", e); + } + return createChannel(); + } + /** * Assuming "pick_first" policy is used, tests that all requests are sent to the same server. 
*/ diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java index 8fa272122d0..a9ee9382495 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java @@ -18,6 +18,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.Queues; +import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.protobuf.ByteString; import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.Metadata; @@ -54,7 +55,6 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import javax.annotation.concurrent.GuardedBy; /** * Implementation of the business logic for the TestService. Uses an executor to schedule chunks diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java index 01cccf98044..fc4cdf9178f 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java @@ -180,18 +180,22 @@ void start() throws Exception { break; case IPV4: SocketAddress v4Address = Util.getV4Address(port); + InetSocketAddress localV4Address = new InetSocketAddress("127.0.0.1", port); serverBuilder = - NettyServerBuilder.forAddress(new InetSocketAddress("127.0.0.1", port), serverCreds); - if (v4Address == null) { + NettyServerBuilder.forAddress(localV4Address, serverCreds); + if (v4Address != null && !v4Address.equals(localV4Address)) { ((NettyServerBuilder) serverBuilder).addListenAddress(v4Address); } break; case IPV6: List v6Addresses = Util.getV6Addresses(port); + InetSocketAddress localV6Address = new 
InetSocketAddress("::1", port); serverBuilder = - NettyServerBuilder.forAddress(new InetSocketAddress("::1", port), serverCreds); + NettyServerBuilder.forAddress(localV6Address, serverCreds); for (SocketAddress address : v6Addresses) { - ((NettyServerBuilder)serverBuilder).addListenAddress(address); + if (!address.equals(localV6Address)) { + ((NettyServerBuilder) serverBuilder).addListenAddress(address); + } } break; default: diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsFederationTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsFederationTestClient.java index f55ccbdefa7..bba282b7b6f 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsFederationTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsFederationTestClient.java @@ -22,9 +22,10 @@ import io.grpc.ChannelCredentials; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; -import io.grpc.ManagedChannelBuilder; +import io.grpc.ManagedChannel; import io.grpc.alts.ComputeEngineChannelCredentials; import java.util.ArrayList; +import java.util.concurrent.TimeUnit; import java.util.logging.Logger; /** @@ -44,26 +45,8 @@ public final class XdsFederationTestClient { public static void main(String[] args) throws Exception { final XdsFederationTestClient client = new XdsFederationTestClient(); client.parseArgs(args); - Runtime.getRuntime() - .addShutdownHook( - new Thread() { - @Override - @SuppressWarnings("CatchAndPrintStackTrace") - public void run() { - System.out.println("Shutting down"); - try { - client.tearDown(); - } catch (RuntimeException e) { - e.printStackTrace(); - } - } - }); client.setUp(); - try { - client.run(); - } finally { - client.tearDown(); - } + client.run(); System.exit(0); } @@ -209,22 +192,13 @@ void setUp() { for (int i = 0; i < uris.length; i++) { clients.add(new InnerClient(creds[i], uris[i])); } - for (InnerClient c : clients) { - c.setUp(); - } - } - - private 
synchronized void tearDown() { - for (InnerClient c : clients) { - c.tearDown(); - } } /** * Wraps a single client stub configuration and executes a * soak test case with that configuration. */ - class InnerClient extends AbstractInteropTest { + class InnerClient { private final String credentialsType; private final String serverUri; private boolean runSucceeded = false; @@ -245,29 +219,43 @@ public boolean runSucceeded() { /** * Run the intended soak test. */ - public void run() { - boolean resetChannelPerIteration; - switch (testCase) { - case "rpc_soak": - resetChannelPerIteration = false; - break; - case "channel_soak": - resetChannelPerIteration = true; - break; - default: - throw new RuntimeException("invalid testcase: " + testCase); - } + public void run() throws InterruptedException { try { - performSoakTest( - serverUri, - resetChannelPerIteration, - soakIterations, - soakMaxFailures, - soakPerIterationMaxAcceptableLatencyMs, - soakMinTimeMsBetweenRpcs, - soakOverallTimeoutSeconds, - soakRequestSize, - soakResponseSize); + switch (testCase) { + case "rpc_soak": { + SoakClient.performSoakTest( + serverUri, + soakIterations, + soakMaxFailures, + soakPerIterationMaxAcceptableLatencyMs, + soakMinTimeMsBetweenRpcs, + soakOverallTimeoutSeconds, + soakRequestSize, + soakResponseSize, + 1, + createChannel(), + (currentChannel) -> currentChannel); + } + break; + case "channel_soak": { + SoakClient.performSoakTest( + serverUri, + soakIterations, + soakMaxFailures, + soakPerIterationMaxAcceptableLatencyMs, + soakMinTimeMsBetweenRpcs, + soakOverallTimeoutSeconds, + soakRequestSize, + soakResponseSize, + 1, + createChannel(), + (currentChannel) -> createNewChannel(currentChannel)); + } + break; + default: + throw new RuntimeException("invalid testcase: " + testCase); + } + logger.info("Test case: " + testCase + " done for server: " + serverUri); runSucceeded = true; } catch (Exception e) { @@ -276,8 +264,7 @@ public void run() { } } - @Override - protected 
ManagedChannelBuilder createChannelBuilder() { + ManagedChannel createChannel() { ChannelCredentials channelCredentials; switch (credentialsType) { case "compute_engine_channel_creds": @@ -291,15 +278,33 @@ protected ManagedChannelBuilder createChannelBuilder() { } return Grpc.newChannelBuilder(serverUri, channelCredentials) .keepAliveTime(3600, SECONDS) - .keepAliveTimeout(20, SECONDS); + .keepAliveTimeout(20, SECONDS) + .build(); + } + + ManagedChannel createNewChannel(ManagedChannel currentChannel) { + currentChannel.shutdownNow(); + try { + currentChannel.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while creating a new channel", e); + } + return createChannel(); } } - private void run() throws Exception { + private void run() throws InterruptedException { logger.info("Begin test case: " + testCase); ArrayList threads = new ArrayList<>(); for (InnerClient c : clients) { - Thread t = new Thread(c::run); + Thread t = new Thread(() -> { + try { + c.run(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); // Properly re-interrupt the thread + throw new RuntimeException("Thread was interrupted during execution", e); + } + }); t.start(); threads.add(t); } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java index c697bd9f305..89519041a79 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java @@ -28,6 +28,7 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; import com.google.protobuf.ByteString; +import io.grpc.BindableService; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -77,6 +78,7 @@ import java.util.logging.Logger; import 
javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** Client for xDS interop tests. */ public final class XdsTestClient { @@ -260,6 +262,7 @@ private static RpcType parseRpc(String rpc) { } } + @IgnoreJRERequirement // OpenTelemetry uses Java 8+ APIs private void run() { if (enableCsmObservability) { csmObservability = CsmObservability.newBuilder() @@ -273,11 +276,13 @@ private void run() { .build(); csmObservability.registerGlobal(); } + @SuppressWarnings("deprecation") + BindableService oldReflectionService = ProtoReflectionService.newInstance(); statsServer = Grpc.newServerBuilderForPort(statsPort, InsecureServerCredentials.create()) .addService(new XdsStatsImpl()) .addService(new ConfigureUpdateServiceImpl()) - .addService(ProtoReflectionService.newInstance()) + .addService(oldReflectionService) .addService(ProtoReflectionServiceV1.newInstance()) .addServices(AdminInterface.getStandardServices()) .build(); diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java index 8c61f2eb2ad..88f1bf468b6 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java @@ -20,6 +20,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.protobuf.ByteString; +import io.grpc.BindableService; import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.Grpc; import io.grpc.InsecureServerCredentials; @@ -56,6 +57,7 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** Interop test server that implements the xDS testing service. 
*/ public final class XdsTestServer { @@ -192,6 +194,7 @@ void parseArgs(String[] args) { } @SuppressWarnings("AddressSelection") + @IgnoreJRERequirement // OpenTelemetry uses Java 8+ APIs void start() throws Exception { if (enableCsmObservability) { csmObservability = CsmObservability.newBuilder() @@ -212,6 +215,8 @@ void start() throws Exception { throw new RuntimeException(e); } health = new HealthStatusManager(); + @SuppressWarnings("deprecation") + BindableService oldReflectionService = ProtoReflectionService.newInstance(); if (secureMode) { if (addressType != Util.AddressType.IPV4_IPV6) { throw new IllegalArgumentException("Secure mode only supports IPV4_IPV6 address type"); @@ -220,7 +225,7 @@ void start() throws Exception { Grpc.newServerBuilderForPort(maintenancePort, InsecureServerCredentials.create()) .addService(new XdsUpdateHealthServiceImpl(health)) .addService(health.getHealthService()) - .addService(ProtoReflectionService.newInstance()) + .addService(oldReflectionService) .addService(ProtoReflectionServiceV1.newInstance()) .addServices(AdminInterface.getStandardServices()) .build(); @@ -242,18 +247,21 @@ void start() throws Exception { break; case IPV4: SocketAddress v4Address = Util.getV4Address(port); + InetSocketAddress localV4Address = new InetSocketAddress("127.0.0.1", port); serverBuilder = NettyServerBuilder.forAddress( - new InetSocketAddress("127.0.0.1", port), insecureServerCreds); - if (v4Address != null) { + localV4Address, insecureServerCreds); + if (v4Address != null && !v4Address.equals(localV4Address) ) { ((NettyServerBuilder) serverBuilder).addListenAddress(v4Address); } break; case IPV6: List v6Addresses = Util.getV6Addresses(port); - serverBuilder = NettyServerBuilder.forAddress( - new InetSocketAddress("::1", port), insecureServerCreds); + InetSocketAddress localV6Address = new InetSocketAddress("::1", port); + serverBuilder = NettyServerBuilder.forAddress(localV6Address, insecureServerCreds); for (SocketAddress address : 
v6Addresses) { - ((NettyServerBuilder)serverBuilder).addListenAddress(address); + if (!address.equals(localV6Address)) { + ((NettyServerBuilder) serverBuilder).addListenAddress(address); + } } break; default: @@ -269,7 +277,7 @@ void start() throws Exception { new TestServiceImpl(serverId, host), new TestInfoInterceptor(host))) .addService(new XdsUpdateHealthServiceImpl(health)) .addService(health.getHealthService()) - .addService(ProtoReflectionService.newInstance()) + .addService(oldReflectionService) .addService(ProtoReflectionServiceV1.newInstance()) .addServices(AdminInterface.getStandardServices()) .build(); diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java index 208eb40c438..5307c26949b 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java @@ -24,6 +24,8 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; import com.google.protobuf.ByteString; import io.grpc.CallOptions; import io.grpc.Channel; @@ -53,8 +55,6 @@ import io.grpc.testing.integration.TestServiceGrpc.TestServiceBlockingStub; import io.grpc.testing.integration.TransportCompressionTest.Fzip; import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.Collection; import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -146,25 +146,16 @@ public void tearDown() { * Parameters for test. 
*/ @Parameters - public static Collection params() { - boolean[] bools = new boolean[]{false, true}; - List combos = new ArrayList<>(64); - for (boolean enableClientMessageCompression : bools) { - for (boolean clientAcceptEncoding : bools) { - for (boolean clientEncoding : bools) { - for (boolean enableServerMessageCompression : bools) { - for (boolean serverAcceptEncoding : bools) { - for (boolean serverEncoding : bools) { - combos.add(new Object[] { - enableClientMessageCompression, clientAcceptEncoding, clientEncoding, - enableServerMessageCompression, serverAcceptEncoding, serverEncoding}); - } - } - } - } - } - } - return combos; + public static Iterable params() { + List bools = Lists.newArrayList(false, true); + return Iterables.transform(Lists.cartesianProduct( + bools, // enableClientMessageCompression + bools, // clientAcceptEncoding + bools, // clientEncoding + bools, // enableServerMessageCompression + bools, // serverAcceptEncoding + bools // serverEncoding + ), List::toArray); } @Test diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/OpenTelemetryContextPropagationTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/OpenTelemetryContextPropagationTest.java new file mode 100644 index 00000000000..3884d977a6e --- /dev/null +++ b/interop-testing/src/test/java/io/grpc/testing/integration/OpenTelemetryContextPropagationTest.java @@ -0,0 +1,191 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static org.junit.Assert.assertEquals; + +import io.grpc.ForwardingServerCallListener; +import io.grpc.InsecureServerCredentials; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.opentelemetry.GrpcOpenTelemetry; +import io.grpc.opentelemetry.GrpcTraceBinContextPropagator; +import io.grpc.opentelemetry.InternalGrpcOpenTelemetry; +import io.grpc.testing.integration.Messages.SimpleRequest; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.api.trace.propagation.W3CTraceContextPropagator; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import io.opentelemetry.context.propagation.ContextPropagators; +import io.opentelemetry.context.propagation.TextMapPropagator; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.trace.SdkTracerProvider; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Assume; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class OpenTelemetryContextPropagationTest extends AbstractInteropTest { + private final OpenTelemetrySdk openTelemetrySdk; + private final Tracer tracer; + private final GrpcOpenTelemetry grpcOpenTelemetry; + private final AtomicReference applicationSpan = new AtomicReference<>(); + private final boolean censusClient; + + @Parameterized.Parameters(name = "ContextPropagator={0}, CensusClient={1}") + public static Iterable data() { + 
return Arrays.asList(new Object[][] { + {W3CTraceContextPropagator.getInstance(), false}, + {GrpcTraceBinContextPropagator.defaultInstance(), false}, + {GrpcTraceBinContextPropagator.defaultInstance(), true} + }); + } + + public OpenTelemetryContextPropagationTest(TextMapPropagator textMapPropagator, + boolean isCensusClient) { + this.openTelemetrySdk = OpenTelemetrySdk.builder() + .setTracerProvider(SdkTracerProvider.builder().build()) + .setPropagators(ContextPropagators.create(TextMapPropagator.composite( + textMapPropagator + ))) + .build(); + this.tracer = openTelemetrySdk + .getTracer("grpc-java-interop-test"); + GrpcOpenTelemetry.Builder grpcOpentelemetryBuilder = GrpcOpenTelemetry.newBuilder() + .sdk(openTelemetrySdk); + InternalGrpcOpenTelemetry.enableTracing(grpcOpentelemetryBuilder, true); + grpcOpenTelemetry = grpcOpentelemetryBuilder.build(); + this.censusClient = isCensusClient; + } + + @Override + protected ServerBuilder getServerBuilder() { + NettyServerBuilder builder = NettyServerBuilder.forPort(0, InsecureServerCredentials.create()) + .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE); + builder.intercept(new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + ServerCall.Listener listener = next.startCall(call, headers); + return new ForwardingServerCallListener() { + @Override + protected ServerCall.Listener delegate() { + return listener; + } + + @Override + public void onMessage(ReqT request) { + applicationSpan.set(tracer.spanBuilder("InteropTest.Application.Span").startSpan()); + delegate().onMessage(request); + } + + @Override + public void onHalfClose() { + maybeCloseSpan(applicationSpan); + delegate().onHalfClose(); + } + + @Override + public void onCancel() { + maybeCloseSpan(applicationSpan); + delegate().onCancel(); + } + + @Override + public void onComplete() { + maybeCloseSpan(applicationSpan); + delegate().onComplete(); + } + }; + } + 
}); + // To ensure proper propagation of remote spans from gRPC to your application, this interceptor + // must be after any application interceptors that interact with spans. This allows the tracing + // information to be correctly passed along. However, it's fine for application-level onMessage + // handlers to access the span. + grpcOpenTelemetry.configureServerBuilder(builder); + return builder; + } + + private void maybeCloseSpan(AtomicReference applicationSpan) { + Span tmp = applicationSpan.get(); + if (tmp != null) { + tmp.end(); + } + } + + @Override + protected boolean metricsExpected() { + return false; + } + + @Override + protected ManagedChannelBuilder createChannelBuilder() { + NettyChannelBuilder builder = NettyChannelBuilder.forAddress(getListenAddress()) + .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) + .usePlaintext(); + if (!censusClient) { + // Disabling census-tracing is necessary to avoid trace ID mismatches. + // This is because census-tracing overrides the grpc-trace-bin header with + // OpenTelemetry's GrpcTraceBinPropagator. 
+ InternalNettyChannelBuilder.setTracingEnabled(builder, false); + grpcOpenTelemetry.configureChannelBuilder(builder); + } + return builder; + } + + @Test + public void otelSpanContextPropagation() { + Assume.assumeFalse(censusClient); + Span parentSpan = tracer.spanBuilder("Test.interopTest").startSpan(); + try (Scope scope = Context.current().with(parentSpan).makeCurrent()) { + blockingStub.unaryCall(SimpleRequest.getDefaultInstance()); + } + assertEquals(parentSpan.getSpanContext().getTraceId(), + applicationSpan.get().getSpanContext().getTraceId()); + } + + @Test + @SuppressWarnings("deprecation") + public void censusToOtelGrpcTraceBinPropagator() { + Assume.assumeTrue(censusClient); + io.opencensus.trace.Tracer censusTracer = io.opencensus.trace.Tracing.getTracer(); + io.opencensus.trace.Span parentSpan = censusTracer.spanBuilder("Test.interopTest") + .startSpan(); + io.grpc.Context context = io.opencensus.trace.unsafe.ContextUtils.withValue( + io.grpc.Context.current(), parentSpan); + io.grpc.Context previous = context.attach(); + try { + blockingStub.unaryCall(SimpleRequest.getDefaultInstance()); + assertEquals(parentSpan.getContext().getTraceId().toLowerBase16(), + applicationSpan.get().getSpanContext().getTraceId()); + } finally { + context.detach(previous); + } + } +} diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/ProxyTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/ProxyTest.java index f550d657a12..725e98d0fe3 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/ProxyTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/ProxyTest.java @@ -62,7 +62,6 @@ public void shutdownTest() throws IOException { } @Test - @org.junit.Ignore // flaky. 
latency commonly too high public void smallLatency() throws Exception { server = new Server(); int serverPort = server.init(); diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java index edd2a57ab9d..669ce1c69db 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java @@ -303,7 +303,7 @@ public void retryUntilBufferLimitExceeded() throws Exception { serverCall.close( Status.UNAVAILABLE.withDescription("original attempt failed"), new Metadata()); - elapseBackoff(10, SECONDS); + elapseBackoff(12, SECONDS); // 2nd attempt received serverCall = serverCalls.poll(5, SECONDS); serverCall.request(2); @@ -348,7 +348,7 @@ public void statsRecorded() throws Exception { Status.UNAVAILABLE.withDescription("original attempt failed"), new Metadata()); assertRpcStatusRecorded(Status.Code.UNAVAILABLE, 1000, 1); - elapseBackoff(10, SECONDS); + elapseBackoff(12, SECONDS); assertRpcStartedRecorded(); assertOutboundMessageRecorded(); serverCall = serverCalls.poll(5, SECONDS); @@ -366,7 +366,7 @@ public void statsRecorded() throws Exception { call.request(1); assertInboundMessageRecorded(); assertInboundWireSizeRecorded(1); - assertRpcStatusRecorded(Status.Code.OK, 12000, 2); + assertRpcStatusRecorded(Status.Code.OK, 14000, 2); assertRetryStatsRecorded(1, 0, 0); } @@ -418,7 +418,7 @@ public void streamClosed(Status status) { Status.UNAVAILABLE.withDescription("original attempt failed"), new Metadata()); assertRpcStatusRecorded(Code.UNAVAILABLE, 5000, 1); - elapseBackoff(10, SECONDS); + elapseBackoff(12, SECONDS); assertRpcStartedRecorded(); assertOutboundMessageRecorded(); serverCall = serverCalls.poll(5, SECONDS); @@ -431,7 +431,7 @@ public void streamClosed(Status status) { streamClosedLatch.countDown(); // The call listener is closed. 
verify(mockCallListener, timeout(5000)).onClose(any(Status.class), any(Metadata.class)); - assertRpcStatusRecorded(Code.CANCELLED, 17_000, 1); + assertRpcStatusRecorded(Code.CANCELLED, 19_000, 1); assertRetryStatsRecorded(1, 0, 0); } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java index 02ede46bcdd..4a43af67ac8 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java @@ -78,6 +78,7 @@ public void parseInvalidConfig() { assertThat(status.getDescription()).contains("rpcBehavior"); } + @Deprecated @Test public void handleResolvedAddressesDelegated() { RpcBehaviorLoadBalancer lb = new RpcBehaviorLoadBalancer(new RpcBehaviorHelper(mockHelper), @@ -87,6 +88,15 @@ public void handleResolvedAddressesDelegated() { verify(mockDelegateLb).handleResolvedAddresses(resolvedAddresses); } + @Test + public void acceptResolvedAddressesDelegated() { + RpcBehaviorLoadBalancer lb = new RpcBehaviorLoadBalancer(new RpcBehaviorHelper(mockHelper), + mockDelegateLb); + ResolvedAddresses resolvedAddresses = buildResolvedAddresses(buildConfig()); + lb.acceptResolvedAddresses(resolvedAddresses); + verify(mockDelegateLb).acceptResolvedAddresses(resolvedAddresses); + } + @Test public void helperWrapsPicker() { RpcBehaviorHelper helper = new RpcBehaviorHelper(mockHelper); diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/StressTestClientTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/StressTestClientTest.java index c09a0cfeab9..a1a2cb9b5ea 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/StressTestClientTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/StressTestClientTest.java @@ -44,13 
+44,13 @@ public class StressTestClientTest { @Rule - public final Timeout globalTimeout = Timeout.seconds(10); + public final Timeout globalTimeout = Timeout.seconds(15); @Test public void ipv6AddressesShouldBeSupported() { StressTestClient client = new StressTestClient(); - client.parseArgs(new String[] {"--server_addresses=[0:0:0:0:0:0:0:1]:8080," - + "[1:2:3:4:f:e:a:b]:8083"}); + client.parseArgs(new String[] { + "--server_addresses=[0:0:0:0:0:0:0:1]:8080,[1:2:3:4:f:e:a:b]:8083"}); assertEquals(2, client.addresses().size()); assertEquals(new InetSocketAddress("0:0:0:0:0:0:0:1", 8080), client.addresses().get(0)); diff --git a/istio-interop-testing/build.gradle b/istio-interop-testing/build.gradle index e2fe228f13b..083d8fcb9bf 100644 --- a/istio-interop-testing/build.gradle +++ b/istio-interop-testing/build.gradle @@ -18,8 +18,6 @@ dependencies { project(':grpc-testing'), project(':grpc-xds') - compileOnly libraries.javax.annotation - runtimeOnly libraries.netty.tcnative, libraries.netty.tcnative.classes testImplementation testFixtures(project(':grpc-api')), @@ -28,7 +26,11 @@ dependencies { libraries.junit, libraries.truth - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } sourceSets { diff --git a/istio-interop-testing/src/generated/main/grpc/io/istio/test/EchoTestServiceGrpc.java b/istio-interop-testing/src/generated/main/grpc/io/istio/test/EchoTestServiceGrpc.java index 1f48c16aed3..61d20d2f7bb 100644 --- a/istio-interop-testing/src/generated/main/grpc/io/istio/test/EchoTestServiceGrpc.java +++ b/istio-interop-testing/src/generated/main/grpc/io/istio/test/EchoTestServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: test/echo/proto/echo.proto") @io.grpc.stub.annotations.GrpcGenerated public final class EchoTestServiceGrpc { @@ -91,6 +88,21 @@ public EchoTestServiceStub newStub(io.grpc.Channel channel, 
io.grpc.CallOptions return EchoTestServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static EchoTestServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public EchoTestServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new EchoTestServiceBlockingV2Stub(channel, callOptions); + } + }; + return EchoTestServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -187,6 +199,37 @@ public void forwardEcho(io.istio.test.Echo.ForwardEchoRequest request, /** * A stub to allow clients to do synchronous rpc calls to service EchoTestService. */ + public static final class EchoTestServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private EchoTestServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected EchoTestServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new EchoTestServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.istio.test.Echo.EchoResponse echo(io.istio.test.Echo.EchoRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getEchoMethod(), getCallOptions(), request); + } + + /** + */ + public io.istio.test.Echo.ForwardEchoResponse forwardEcho(io.istio.test.Echo.ForwardEchoRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getForwardEchoMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service 
EchoTestService. + */ public static final class EchoTestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private EchoTestServiceBlockingStub( diff --git a/java_grpc_library.bzl b/java_grpc_library.bzl index 630ced383f7..e6afc028883 100644 --- a/java_grpc_library.bzl +++ b/java_grpc_library.bzl @@ -1,5 +1,6 @@ """Build rule for java_grpc_library.""" +load("@com_google_protobuf//bazel/common:proto_info.bzl", "ProtoInfo") load("@rules_java//java:defs.bzl", "JavaInfo", "JavaPluginInfo", "java_common") _JavaRpcToolchainInfo = provider( @@ -91,15 +92,15 @@ def _java_rpc_library_impl(ctx): srcjar = ctx.actions.declare_file("%s-proto-gensrc.jar" % ctx.label.name) args = ctx.actions.args() - args.add(toolchain.plugin.files_to_run.executable, format = "--plugin=protoc-gen-rpc-plugin=%s") + args.add(toolchain.plugin[DefaultInfo].files_to_run.executable, format = "--plugin=protoc-gen-rpc-plugin=%s") args.add("--rpc-plugin_out={0}:{1}".format(toolchain.plugin_arg, srcjar.path)) args.add_joined("--descriptor_set_in", descriptor_set_in, join_with = ctx.configuration.host_path_separator) args.add_all(srcs, map_each = _path_ignoring_repository) ctx.actions.run( - inputs = depset(srcs, transitive = [descriptor_set_in, toolchain.plugin.files]), + inputs = depset(srcs, transitive = [descriptor_set_in, toolchain.plugin[DefaultInfo].files]), outputs = [srcjar], - executable = toolchain.protoc.files_to_run, + executable = toolchain.protoc[DefaultInfo].files_to_run, arguments = [args], use_default_shell_env = True, toolchain = None, @@ -147,6 +148,33 @@ _java_grpc_library = rule( implementation = _java_rpc_library_impl, ) +# A copy of _java_grpc_library, except with a neverlink=1 _toolchain +INTERNAL_java_grpc_library_for_xds = rule( + attrs = { + "srcs": attr.label_list( + mandatory = True, + allow_empty = False, + providers = [ProtoInfo], + ), + "deps": attr.label_list( + mandatory = True, + allow_empty = False, + providers = [JavaInfo], + ), + "_toolchain": attr.label( + 
default = Label("//xds:java_grpc_library_toolchain"), + ), + }, + toolchains = ["@bazel_tools//tools/jdk:toolchain_type"], + fragments = ["java"], + outputs = { + "jar": "lib%{name}.jar", + "srcjar": "lib%{name}-src.jar", + }, + provides = [JavaInfo], + implementation = _java_rpc_library_impl, +) + _java_lite_grpc_library = rule( attrs = { "srcs": attr.label_list( diff --git a/lint.xml b/lint.xml new file mode 100644 index 00000000000..5b35a8d151b --- /dev/null +++ b/lint.xml @@ -0,0 +1,13 @@ + + + + + + + + diff --git a/netty/BUILD.bazel b/netty/BUILD.bazel index 9fe52ea5868..8253d1f5bff 100644 --- a/netty/BUILD.bazel +++ b/netty/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( @@ -27,6 +28,7 @@ java_library( artifact("io.netty:netty-transport"), artifact("io.netty:netty-transport-native-unix-common"), artifact("io.perfmark:perfmark-api"), + artifact("org.codehaus.mojo:animal-sniffer-annotations"), ], ) diff --git a/netty/build.gradle b/netty/build.gradle index 5533038c85b..cb97ae10b55 100644 --- a/netty/build.gradle +++ b/netty/build.gradle @@ -17,6 +17,7 @@ tasks.named("jar").configure { dependencies { api project(':grpc-api'), + libraries.animalsniffer.annotations, libraries.netty.codec.http2 implementation project(':grpc-core'), libs.netty.handler.proxy, @@ -65,8 +66,16 @@ dependencies { classifier = "linux-x86_64" } } - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } import net.ltgt.gradle.errorprone.CheckSeverity diff --git a/netty/shaded/BUILD.bazel b/netty/shaded/BUILD.bazel index 6248f406de9..0a93907bd2f 100644 --- a/netty/shaded/BUILD.bazel +++ b/netty/shaded/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") 
load("@rules_jvm_external//:defs.bzl", "artifact") # Publicly exposed in //netty package. Purposefully does not export any symbols. diff --git a/netty/shaded/build.gradle b/netty/shaded/build.gradle index a1151ca2427..27816f9380b 100644 --- a/netty/shaded/build.gradle +++ b/netty/shaded/build.gradle @@ -9,7 +9,7 @@ plugins { id "java" id "maven-publish" - id "com.github.johnrengelman.shadow" + id "com.gradleup.shadow" id "ru.vyarus.animalsniffer" } @@ -65,8 +65,16 @@ dependencies { shadow project(':grpc-netty').configurations.runtimeClasspath.allDependencies.matching { it.group != 'io.netty' } - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("jar").configure { @@ -145,7 +153,11 @@ class NettyResourceTransformer implements Transformer { @Override boolean canTransformResource(FileTreeElement fileTreeElement) { - fileTreeElement.name.startsWith("META-INF/native-image/io.netty") + // io.netty.versions.properties can't actually be shaded successfully, + // as io.netty.util.Version still looks for the unshaded name. But we + // keep the file for manual inspection. 
+ fileTreeElement.name.startsWith("META-INF/native-image/io.netty") || + fileTreeElement.name.startsWith("META-INF/io.netty.versions.properties") } @Override diff --git a/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java b/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java index 7f088509c04..c4ec5913cde 100644 --- a/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java +++ b/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java @@ -42,13 +42,15 @@ abstract class AbstractNettyHandler extends GrpcHttp2ConnectionHandler { private final int initialConnectionWindow; private final FlowControlPinger flowControlPing; - + protected final int maxHeaderListSize; + protected final int softLimitHeaderListSize; private boolean autoTuneFlowControlOn; private ChannelHandlerContext ctx; private boolean initialWindowSent = false; private final Ticker ticker; private static final long BDP_MEASUREMENT_PING = 1234; + protected static final int MIN_ALLOCATED_CHUNK = 16 * 1024; AbstractNettyHandler( ChannelPromise channelUnused, @@ -58,7 +60,9 @@ abstract class AbstractNettyHandler extends GrpcHttp2ConnectionHandler { ChannelLogger negotiationLogger, boolean autoFlowControl, PingLimiter pingLimiter, - Ticker ticker) { + Ticker ticker, + int maxHeaderListSize, + int softLimitHeaderListSize) { super(channelUnused, decoder, encoder, initialSettings, negotiationLogger); // During a graceful shutdown, wait until all streams are closed. 
@@ -73,6 +77,8 @@ abstract class AbstractNettyHandler extends GrpcHttp2ConnectionHandler { } this.flowControlPing = new FlowControlPinger(pingLimiter); this.ticker = checkNotNull(ticker, "ticker"); + this.maxHeaderListSize = maxHeaderListSize; + this.softLimitHeaderListSize = softLimitHeaderListSize; } @Override diff --git a/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java b/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java index d49e3bd672b..a3b29457670 100644 --- a/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java +++ b/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java @@ -69,13 +69,14 @@ public boolean equals(Object o) { CancelServerStreamCommand that = (CancelServerStreamCommand) o; - return Objects.equal(this.stream, that.stream) - && Objects.equal(this.reason, that.reason); + return this.stream.equals(that.stream) + && this.reason.equals(that.reason) + && this.peerNotify.equals(that.peerNotify); } @Override public int hashCode() { - return Objects.hashCode(stream, reason); + return Objects.hashCode(stream, reason, peerNotify); } @Override @@ -83,6 +84,7 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("stream", stream) .add("reason", reason) + .add("peerNotify", peerNotify) .toString(); } diff --git a/netty/src/main/java/io/grpc/netty/ClientTransportLifecycleManager.java b/netty/src/main/java/io/grpc/netty/ClientTransportLifecycleManager.java index 34f72ab97bd..01e7bc3ed12 100644 --- a/netty/src/main/java/io/grpc/netty/ClientTransportLifecycleManager.java +++ b/netty/src/main/java/io/grpc/netty/ClientTransportLifecycleManager.java @@ -19,6 +19,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.grpc.Attributes; import io.grpc.Status; +import io.grpc.internal.DisconnectError; import io.grpc.internal.ManagedClientTransport; /** Maintainer of transport lifecycle status. 
*/ @@ -30,7 +31,6 @@ final class ClientTransportLifecycleManager { /** null iff !transportShutdown. */ private Status shutdownStatus; /** null iff !transportShutdown. */ - private Throwable shutdownThrowable; private boolean transportTerminated; public ClientTransportLifecycleManager(ManagedClientTransport.Listener listener) { @@ -56,23 +56,22 @@ public void notifyReady() { * Marks transport as shutdown, but does not set the error status. This must eventually be * followed by a call to notifyShutdown. */ - public void notifyGracefulShutdown(Status s) { + public void notifyGracefulShutdown(Status s, DisconnectError disconnectError) { if (transportShutdown) { return; } transportShutdown = true; - listener.transportShutdown(s); + listener.transportShutdown(s, disconnectError); } /** Returns {@code true} if was the first shutdown. */ @CanIgnoreReturnValue - public boolean notifyShutdown(Status s) { - notifyGracefulShutdown(s); + public boolean notifyShutdown(Status s, DisconnectError disconnectError) { + notifyGracefulShutdown(s, disconnectError); if (shutdownStatus != null) { return false; } shutdownStatus = s; - shutdownThrowable = s.asException(); return true; } @@ -84,12 +83,12 @@ public void notifyInUse(boolean inUse) { listener.transportInUse(inUse); } - public void notifyTerminated(Status s) { + public void notifyTerminated(Status s, DisconnectError disconnectError) { if (transportTerminated) { return; } transportTerminated = true; - notifyShutdown(s); + notifyShutdown(s, disconnectError); listener.transportTerminated(); } @@ -97,7 +96,4 @@ public Status getShutdownStatus() { return shutdownStatus; } - public Throwable getShutdownThrowable() { - return shutdownThrowable; - } } diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java index 3b8c595a12e..ee5227484fb 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java +++ 
b/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java @@ -18,7 +18,6 @@ import static com.google.common.base.Preconditions.checkState; -import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.Internal; @@ -28,7 +27,6 @@ import io.netty.handler.codec.http2.Http2ConnectionEncoder; import io.netty.handler.codec.http2.Http2ConnectionHandler; import io.netty.handler.codec.http2.Http2Settings; -import io.netty.util.Version; import javax.annotation.Nullable; /** @@ -36,38 +34,11 @@ */ @Internal public abstract class GrpcHttp2ConnectionHandler extends Http2ConnectionHandler { - static final int ADAPTIVE_CUMULATOR_COMPOSE_MIN_SIZE_DEFAULT = 1024; - static final Cumulator ADAPTIVE_CUMULATOR = - new NettyAdaptiveCumulator(ADAPTIVE_CUMULATOR_COMPOSE_MIN_SIZE_DEFAULT); - @Nullable protected final ChannelPromise channelUnused; private final ChannelLogger negotiationLogger; - private static final boolean usingPre4_1_111_Netty; - - static { - // Netty 4.1.111 introduced a change in the behavior of duplicate() method - // that breaks the assumption of the cumulator. We need to detect this version - // and adjust the behavior accordingly. - - boolean identifiedOldVersion = false; - try { - Version version = Version.identify().get("netty-buffer"); - if (version != null) { - String[] split = version.artifactVersion().split("\\."); - if (split.length >= 3 - && Integer.parseInt(split[0]) == 4 - && Integer.parseInt(split[1]) <= 1 - && Integer.parseInt(split[2]) < 111) { - identifiedOldVersion = true; - } - } - } catch (Exception e) { - // Ignore, we'll assume it's a new version. 
- } - usingPre4_1_111_Netty = identifiedOldVersion; - } + @SuppressWarnings("this-escape") protected GrpcHttp2ConnectionHandler( ChannelPromise channelUnused, Http2ConnectionDecoder decoder, @@ -77,16 +48,6 @@ protected GrpcHttp2ConnectionHandler( super(decoder, encoder, initialSettings); this.channelUnused = channelUnused; this.negotiationLogger = negotiationLogger; - if (usingPre4_1_111_Netty()) { - // We need to use the adaptive cumulator only if we're using a version of Netty that - // doesn't have the behavior that breaks it. - setCumulator(ADAPTIVE_CUMULATOR); - } - } - - @VisibleForTesting - static boolean usingPre4_1_111_Netty() { - return usingPre4_1_111_Netty; } /** diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2OutboundHeaders.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2OutboundHeaders.java index 0489e135813..aabcd4fbaaa 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcHttp2OutboundHeaders.java +++ b/netty/src/main/java/io/grpc/netty/GrpcHttp2OutboundHeaders.java @@ -66,6 +66,16 @@ private GrpcHttp2OutboundHeaders(AsciiString[] preHeaders, byte[][] serializedMe this.preHeaders = preHeaders; } + @Override + public CharSequence authority() { + for (int i = 0; i < preHeaders.length / 2; i++) { + if (preHeaders[i * 2].equals(Http2Headers.PseudoHeaderName.AUTHORITY.value())) { + return preHeaders[i * 2 + 1]; + } + } + return null; + } + @Override @SuppressWarnings("ReferenceEquality") // STATUS.value() never changes. 
public CharSequence status() { diff --git a/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java b/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java index 04a290165d7..f1f2c8aed71 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java +++ b/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java @@ -84,6 +84,7 @@ private GrpcSslContexts() {} private static final String SUN_PROVIDER_NAME = "SunJSSE"; private static final String IBM_PROVIDER_NAME = "IBMJSSE2"; private static final String OPENJSSE_PROVIDER_NAME = "OpenJSSE"; + private static final String BCJSSE_PROVIDER_NAME = "BCJSSE"; /** * Creates an SslContextBuilder with ciphers and APN appropriate for gRPC. @@ -199,7 +200,8 @@ public static SslContextBuilder configure(SslContextBuilder builder, Provider jd jdkProvider.getName() + " selected, but Java 9+ and Jetty NPN/ALPN unavailable"); } } else if (IBM_PROVIDER_NAME.equals(jdkProvider.getName()) - || OPENJSSE_PROVIDER_NAME.equals(jdkProvider.getName())) { + || OPENJSSE_PROVIDER_NAME.equals(jdkProvider.getName()) + || BCJSSE_PROVIDER_NAME.equals(jdkProvider.getName())) { if (JettyTlsUtil.isJava9AlpnAvailable()) { apc = ALPN; } else { @@ -255,7 +257,8 @@ private static Provider findJdkProvider() { return provider; } } else if (IBM_PROVIDER_NAME.equals(provider.getName()) - || OPENJSSE_PROVIDER_NAME.equals(provider.getName())) { + || OPENJSSE_PROVIDER_NAME.equals(provider.getName()) + || BCJSSE_PROVIDER_NAME.equals(provider.getName())) { if (JettyTlsUtil.isJava9AlpnAvailable()) { return provider; } diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index 0d309828c6d..35dc1bbc2e8 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -16,13 +16,17 @@ package io.grpc.netty; +import com.google.common.base.Optional; import io.grpc.ChannelLogger; +import 
io.grpc.internal.ObjectPool; import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler; import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler; import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler; import io.netty.channel.ChannelHandler; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; +import java.util.concurrent.Executor; +import javax.net.ssl.X509TrustManager; /** * Internal accessor for {@link ProtocolNegotiators}. @@ -35,9 +39,15 @@ private InternalProtocolNegotiators() {} * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. + * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ - public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) { - final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext); + public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, + ObjectPool executorPool, + Optional handshakeCompleteRunnable, + X509TrustManager extendedX509TrustManager, + String sni) { + final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, + executorPool, handshakeCompleteRunnable, extendedX509TrustManager, sni); final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @Override @@ -55,10 +65,21 @@ public void close() { negotiator.close(); } } - + return new TlsNegotiator(); } + /** + * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will + * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} + * may happen immediately, even before the TLS Handshake is complete. 
+ */ + public static InternalProtocolNegotiator.ProtocolNegotiator tls( + SslContext sslContext, String sni, + X509TrustManager extendedX509TrustManager) { + return tls(sslContext, null, Optional.absent(), extendedX509TrustManager, sni); + } + /** * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will be * negotiated, the server TLS {@code handler} is added and writes to the {@link @@ -153,7 +174,8 @@ public static ChannelHandler grpcNegotiationHandler(GrpcHttp2ConnectionHandler n public static ChannelHandler clientTlsHandler( ChannelHandler next, SslContext sslContext, String authority, ChannelLogger negotiationLogger) { - return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger); + return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger, + Optional.absent(), null, null); } public static class ProtocolNegotiationHandler diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 305ad128454..e64f1065681 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -23,6 +23,7 @@ import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Optional; import com.google.common.base.Ticker; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.CheckReturnValue; @@ -103,6 +104,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2 0, "maxInboundMetadataSize must be > 0"); this.maxHeaderListSize = bytes; + // Clear the soft limit setting, by setting soft limit to maxInboundMetadataSize. The + // maxInboundMetadataSize will take precedence be applied before soft limit check. 
+ this.softLimitHeaderListSize = bytes; + return this; + } + + /** + * Sets the size of metadata that clients are advised to not exceed. When a metadata with size + * larger than the soft limit is encountered there will be a probability the RPC will fail. The + * chance of failing increases as the metadata size approaches the hard limit. + * {@code Integer.MAX_VALUE} disables the enforcement. The default is implementation-dependent, + * but is not generally less than 8 KiB and may be unlimited. + * + *

This is cumulative size of the metadata. The precise calculation is + * implementation-dependent, but implementations are encouraged to follow the calculation used + * for + * HTTP/2's + * SETTINGS_MAX_HEADER_LIST_SIZE. It sums the bytes from each entry's key and value, plus 32 + * bytes of overhead per entry. + * + * @param soft the soft size limit of received metadata + * @param max the hard size limit of received metadata + * @return this + * @throws IllegalArgumentException if soft and/or max is non-positive, or max smaller than + * soft + * @since 1.68.0 + */ + @CanIgnoreReturnValue + public NettyChannelBuilder maxInboundMetadataSize(int soft, int max) { + checkArgument(soft > 0, "softLimitHeaderListSize must be > 0"); + checkArgument(max > soft, + "maxInboundMetadataSize must be greater than softLimitHeaderListSize"); + this.softLimitHeaderListSize = soft; + this.maxHeaderListSize = max; return this; } @@ -572,10 +608,22 @@ ClientTransportFactory buildTransportFactory() { ProtocolNegotiator negotiator = protocolNegotiatorFactory.newNegotiator(); return new NettyTransportFactory( - negotiator, channelFactory, channelOptions, - eventLoopGroupPool, autoFlowControl, flowControlWindow, maxInboundMessageSize, - maxHeaderListSize, keepAliveTimeNanos, keepAliveTimeoutNanos, keepAliveWithoutCalls, - transportTracerFactory, localSocketPicker, useGetForSafeMethods, transportSocketType); + negotiator, + channelFactory, + channelOptions, + eventLoopGroupPool, + autoFlowControl, + flowControlWindow, + maxInboundMessageSize, + maxHeaderListSize, + softLimitHeaderListSize, + keepAliveTimeNanos, + keepAliveTimeoutNanos, + keepAliveWithoutCalls, + transportTracerFactory, + localSocketPicker, + useGetForSafeMethods, + transportSocketType); } @VisibleForTesting @@ -604,7 +652,7 @@ static ProtocolNegotiator createProtocolNegotiatorByType( case PLAINTEXT_UPGRADE: return ProtocolNegotiators.plaintextUpgrade(); case TLS: - return ProtocolNegotiators.tls(sslContext, executorPool); + 
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null, null); default: throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType); } @@ -709,6 +757,7 @@ private static final class NettyTransportFactory implements ClientTransportFacto private final int flowControlWindow; private final int maxMessageSize; private final int maxHeaderListSize; + private final int softLimitHeaderListSize; private final long keepAliveTimeNanos; private final AtomicBackoff keepAliveBackoff; private final long keepAliveTimeoutNanos; @@ -723,11 +772,20 @@ private static final class NettyTransportFactory implements ClientTransportFacto NettyTransportFactory( ProtocolNegotiator protocolNegotiator, ChannelFactory channelFactory, - Map, ?> channelOptions, ObjectPool groupPool, - boolean autoFlowControl, int flowControlWindow, int maxMessageSize, int maxHeaderListSize, - long keepAliveTimeNanos, long keepAliveTimeoutNanos, boolean keepAliveWithoutCalls, - TransportTracer.Factory transportTracerFactory, LocalSocketPicker localSocketPicker, - boolean useGetForSafeMethods, Class transportSocketType) { + Map, ?> channelOptions, + ObjectPool groupPool, + boolean autoFlowControl, + int flowControlWindow, + int maxMessageSize, + int maxHeaderListSize, + int softLimitHeaderListSize, + long keepAliveTimeNanos, + long keepAliveTimeoutNanos, + boolean keepAliveWithoutCalls, + TransportTracer.Factory transportTracerFactory, + LocalSocketPicker localSocketPicker, + boolean useGetForSafeMethods, + Class transportSocketType) { this.protocolNegotiator = checkNotNull(protocolNegotiator, "protocolNegotiator"); this.channelFactory = channelFactory; this.channelOptions = new HashMap, Object>(channelOptions); @@ -737,6 +795,7 @@ private static final class NettyTransportFactory implements ClientTransportFacto this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; + this.softLimitHeaderListSize 
= softLimitHeaderListSize; this.keepAliveTimeNanos = keepAliveTimeNanos; this.keepAliveBackoff = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos); this.keepAliveTimeoutNanos = keepAliveTimeoutNanos; @@ -759,6 +818,7 @@ public ConnectionClientTransport newClientTransport( serverAddress = proxiedAddr.getTargetAddress(); localNegotiator = ProtocolNegotiators.httpProxy( proxiedAddr.getProxyAddress(), + proxiedAddr.getHeaders(), proxiedAddr.getUsername(), proxiedAddr.getPassword(), protocolNegotiator); @@ -773,13 +833,31 @@ public void run() { }; // TODO(carl-mastrangelo): Pass channelLogger in. - NettyClientTransport transport = new NettyClientTransport( - serverAddress, channelFactory, channelOptions, group, - localNegotiator, autoFlowControl, flowControlWindow, - maxMessageSize, maxHeaderListSize, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos, - keepAliveWithoutCalls, options.getAuthority(), options.getUserAgent(), - tooManyPingsRunnable, transportTracerFactory.create(), options.getEagAttributes(), - localSocketPicker, channelLogger, useGetForSafeMethods, Ticker.systemTicker()); + NettyClientTransport transport = + new NettyClientTransport( + serverAddress, + channelFactory, + channelOptions, + group, + localNegotiator, + autoFlowControl, + flowControlWindow, + maxMessageSize, + maxHeaderListSize, + softLimitHeaderListSize, + keepAliveTimeNanosState.get(), + keepAliveTimeoutNanos, + keepAliveWithoutCalls, + options.getAuthority(), + options.getUserAgent(), + tooManyPingsRunnable, + transportTracerFactory.create(), + options.getEagAttributes(), + localSocketPicker, + channelLogger, + useGetForSafeMethods, + options.getMetricRecorder(), + Ticker.systemTicker()); return transport; } @@ -795,11 +873,24 @@ public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials ch if (result.error != null) { return null; } - ClientTransportFactory factory = new NettyTransportFactory( - result.negotiator.newNegotiator(), channelFactory, 
channelOptions, groupPool, - autoFlowControl, flowControlWindow, maxMessageSize, maxHeaderListSize, keepAliveTimeNanos, - keepAliveTimeoutNanos, keepAliveWithoutCalls, transportTracerFactory, localSocketPicker, - useGetForSafeMethods, transportSocketType); + ClientTransportFactory factory = + new NettyTransportFactory( + result.negotiator.newNegotiator(), + channelFactory, + channelOptions, + groupPool, + autoFlowControl, + flowControlWindow, + maxMessageSize, + maxHeaderListSize, + softLimitHeaderListSize, + keepAliveTimeNanos, + keepAliveTimeoutNanos, + keepAliveWithoutCalls, + transportTracerFactory, + localSocketPicker, + useGetForSafeMethods, + transportSocketType); return new SwapChannelCredentialsResult(factory, result.callCredentials); } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index eb4dbf8cc66..14a1d7535ad 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -28,16 +28,21 @@ import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.InternalChannelz; +import io.grpc.InternalStatus; import io.grpc.Metadata; +import io.grpc.MetricRecorder; import io.grpc.Status; import io.grpc.StatusException; import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.ClientTransport.PingCallback; +import io.grpc.internal.DisconnectError; +import io.grpc.internal.GoAwayDisconnectError; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; import io.grpc.internal.Http2Ping; import io.grpc.internal.InUseStateAggregator; import io.grpc.internal.KeepAliveManager; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.internal.TransportTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ClientHeadersDecoder; import io.netty.buffer.ByteBuf; @@ -77,12 +82,14 @@ import io.netty.handler.codec.http2.Http2Stream; import 
io.netty.handler.codec.http2.Http2StreamVisitor; import io.netty.handler.codec.http2.StreamBufferingEncoder; -import io.netty.handler.codec.http2.WeightedFairQueueByteDistributor; +import io.netty.handler.codec.http2.UniformStreamByteDistributor; import io.netty.handler.logging.LogLevel; import io.perfmark.PerfMark; import io.perfmark.Tag; import io.perfmark.TaskCloseable; import java.nio.channels.ClosedChannelException; +import java.util.LinkedHashMap; +import java.util.Map; import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; @@ -94,6 +101,8 @@ */ class NettyClientHandler extends AbstractNettyHandler { private static final Logger logger = Logger.getLogger(NettyClientHandler.class.getName()); + static boolean enablePerRpcAuthorityCheck = + GrpcUtil.getFlag("GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK", false); /** * A message that simply passes through the channel without any real processing. It is useful to @@ -115,6 +124,7 @@ class NettyClientHandler extends AbstractNettyHandler { private final Supplier stopwatchFactory; private final TransportTracer transportTracer; private final Attributes eagAttributes; + private final TcpMetrics tcpMetrics; private final String authority; private final InUseStateAggregator inUseState = new InUseStateAggregator() { @@ -128,6 +138,13 @@ protected void handleNotInUse() { lifecycleManager.notifyInUse(false); } }; + private final Map peerVerificationResults = + new LinkedHashMap() { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > 100; + } + }; private WriteQueue clientWriteQueue; private Http2Ping ping; @@ -142,13 +159,15 @@ static NettyClientHandler newHandler( boolean autoFlowControl, int flowControlWindow, int maxHeaderListSize, + int softLimitHeaderListSize, Supplier stopwatchFactory, Runnable tooManyPingsRunnable, TransportTracer transportTracer, Attributes eagAttributes, String authority, ChannelLogger negotiationLogger, - Ticker ticker) { + 
Ticker ticker, + MetricRecorder metricRecorder) { Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive"); Http2HeadersDecoder headersDecoder = new GrpcHttp2ClientHeadersDecoder(maxHeaderListSize); Http2FrameReader frameReader = new DefaultHttp2FrameReader(headersDecoder); @@ -156,8 +175,8 @@ static NettyClientHandler newHandler( Http2HeadersEncoder.NEVER_SENSITIVE, false, 16, Integer.MAX_VALUE); Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter(encoder); Http2Connection connection = new DefaultHttp2Connection(false); - WeightedFairQueueByteDistributor dist = new WeightedFairQueueByteDistributor(connection); - dist.allocationQuantum(16 * 1024); // Make benchmarks fast again. + UniformStreamByteDistributor dist = new UniformStreamByteDistributor(connection); + dist.minAllocationChunk(MIN_ALLOCATED_CHUNK); // Increased for benchmarks performance. DefaultHttp2RemoteFlowController controller = new DefaultHttp2RemoteFlowController(connection, dist); connection.remote().flowController(controller); @@ -171,13 +190,15 @@ static NettyClientHandler newHandler( autoFlowControl, flowControlWindow, maxHeaderListSize, + softLimitHeaderListSize, stopwatchFactory, tooManyPingsRunnable, transportTracer, eagAttributes, authority, negotiationLogger, - ticker); + ticker, + metricRecorder); } @VisibleForTesting @@ -190,18 +211,22 @@ static NettyClientHandler newHandler( boolean autoFlowControl, int flowControlWindow, int maxHeaderListSize, + int softLimitHeaderListSize, Supplier stopwatchFactory, Runnable tooManyPingsRunnable, TransportTracer transportTracer, Attributes eagAttributes, String authority, ChannelLogger negotiationLogger, - Ticker ticker) { + Ticker ticker, + MetricRecorder metricRecorder) { Preconditions.checkNotNull(connection, "connection"); Preconditions.checkNotNull(frameReader, "frameReader"); Preconditions.checkNotNull(lifecycleManager, "lifecycleManager"); Preconditions.checkArgument(flowControlWindow > 0, 
"flowControlWindow must be positive"); Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive"); + Preconditions.checkArgument(softLimitHeaderListSize > 0, + "softLimitHeaderListSize must be positive"); Preconditions.checkNotNull(stopwatchFactory, "stopwatchFactory"); Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable"); Preconditions.checkNotNull(eagAttributes, "eagAttributes"); @@ -247,7 +272,10 @@ static NettyClientHandler newHandler( authority, autoFlowControl, pingCounter, - ticker); + ticker, + maxHeaderListSize, + softLimitHeaderListSize, + metricRecorder); } private NettyClientHandler( @@ -264,9 +292,21 @@ private NettyClientHandler( String authority, boolean autoFlowControl, PingLimiter pingLimiter, - Ticker ticker) { - super(/* channelUnused= */ null, decoder, encoder, settings, - negotiationLogger, autoFlowControl, pingLimiter, ticker); + Ticker ticker, + int maxHeaderListSize, + int softLimitHeaderListSize, + MetricRecorder metricRecorder) { + super( + /* channelUnused= */ null, + decoder, + encoder, + settings, + negotiationLogger, + autoFlowControl, + pingLimiter, + ticker, + maxHeaderListSize, + softLimitHeaderListSize); this.lifecycleManager = lifecycleManager; this.keepAliveManager = keepAliveManager; this.stopwatchFactory = stopwatchFactory; @@ -275,6 +315,7 @@ private NettyClientHandler( this.authority = authority; this.attributes = Attributes.newBuilder() .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttributes).build(); + this.tcpMetrics = new TcpMetrics(metricRecorder); // Set the frame listener on the decoder. 
decoder().frameListener(new FrameListener()); @@ -380,6 +421,28 @@ private void onHeadersRead(int streamId, Http2Headers headers, boolean endStream if (streamId != Http2CodecUtil.HTTP_UPGRADE_STREAM_ID) { NettyClientStream.TransportState stream = clientStream(requireHttp2Stream(streamId)); PerfMark.event("NettyClientHandler.onHeadersRead", stream.tag()); + // check metadata size vs soft limit + int h2HeadersSize = Utils.getH2HeadersSize(headers); + boolean shouldFail = + Utils.shouldRejectOnMetadataSizeSoftLimitExceeded( + h2HeadersSize, softLimitHeaderListSize, maxHeaderListSize); + if (shouldFail && endStream) { + stream.transportReportStatus(Status.RESOURCE_EXHAUSTED + .withDescription( + String.format( + "Server Status + Trailers of size %d exceeded Metadata size soft limit: %d", + h2HeadersSize, + softLimitHeaderListSize)), true, new Metadata()); + return; + } else if (shouldFail) { + stream.transportReportStatus(Status.RESOURCE_EXHAUSTED + .withDescription( + String.format( + "Server Headers of size %d exceeded Metadata size soft limit: %d", + h2HeadersSize, + softLimitHeaderListSize)), true, new Metadata()); + return; + } stream.transportHeadersReceived(headers, endStream); } @@ -423,10 +486,12 @@ private void onRstStreamRead(int streamId, long errorCode) { @Override public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + tcpMetrics.recordTcpInfo(ctx.channel()); logger.fine("Network channel being closed by the application."); if (ctx.channel().isActive()) { // Ignore notification that the socket was closed lifecycleManager.notifyShutdown( - Status.UNAVAILABLE.withDescription("Transport closed for unknown reason")); + Status.UNAVAILABLE.withDescription("Transport closed for unknown reason"), + SimpleDisconnectError.UNKNOWN); } super.close(ctx, promise); } @@ -434,12 +499,19 @@ public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exce /** * Handler for the Channel shutting down. 
*/ + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + tcpMetrics.channelActive(ctx.channel()); + super.channelActive(ctx); + } + @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { try { logger.fine("Network channel is closed"); + tcpMetrics.channelInactive(ctx.channel()); Status status = Status.UNAVAILABLE.withDescription("Network closed for unknown reason"); - lifecycleManager.notifyShutdown(status); + lifecycleManager.notifyShutdown(status, SimpleDisconnectError.UNKNOWN); final Status streamStatus; if (channelInactiveReason != null) { streamStatus = channelInactiveReason; @@ -447,7 +519,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { streamStatus = lifecycleManager.getShutdownStatus(); } try { - cancelPing(lifecycleManager.getShutdownThrowable()); + cancelPing(lifecycleManager.getShutdownStatus()); // Report status to the application layer for any open streams connection().forEachActiveStream(new Http2StreamVisitor() { @Override @@ -460,7 +532,7 @@ public boolean visit(Http2Stream stream) throws Http2Exception { } }); } finally { - lifecycleManager.notifyTerminated(status); + lifecycleManager.notifyTerminated(status, SimpleDisconnectError.UNKNOWN); } } finally { // Close any open streams @@ -508,7 +580,8 @@ InternalChannelz.Security getSecurityInfo() { protected void onConnectionError(ChannelHandlerContext ctx, boolean outbound, Throwable cause, Http2Exception http2Ex) { logger.log(Level.FINE, "Caught a connection error", cause); - lifecycleManager.notifyShutdown(Utils.statusFromThrowable(cause)); + lifecycleManager.notifyShutdown(Utils.statusFromThrowable(cause), + SimpleDisconnectError.SOCKET_ERROR); // Parent class will shut down the Channel super.onConnectionError(ctx, outbound, cause, http2Ex); } @@ -541,16 +614,67 @@ protected boolean isGracefulShutdownComplete() { */ private void createStream(CreateStreamCommand command, ChannelPromise promise) throws 
Exception { - if (lifecycleManager.getShutdownThrowable() != null) { + if (lifecycleManager.getShutdownStatus() != null) { command.stream().setNonExistent(); // The connection is going away (it is really the GOAWAY case), // just terminate the stream now. command.stream().transportReportStatus( lifecycleManager.getShutdownStatus(), RpcProgress.MISCARRIED, true, new Metadata()); - promise.setFailure(lifecycleManager.getShutdownThrowable()); + promise.setFailure(InternalStatus.asRuntimeExceptionWithoutStacktrace( + lifecycleManager.getShutdownStatus(), null)); return; } + CharSequence authorityHeader = command.headers().authority(); + if (authorityHeader == null) { + Status authorityVerificationStatus = Status.UNAVAILABLE.withDescription( + "Missing authority header"); + command.stream().setNonExistent(); + command.stream().transportReportStatus( + Status.UNAVAILABLE, RpcProgress.PROCESSED, true, new Metadata()); + promise.setFailure(InternalStatus.asRuntimeExceptionWithoutStacktrace( + authorityVerificationStatus, null)); + return; + } + // No need to verify authority for the rpc outgoing header if it is same as the authority + // for the transport + if (!authority.contentEquals(authorityHeader)) { + Status authorityVerificationStatus = peerVerificationResults.get( + authorityHeader.toString()); + if (authorityVerificationStatus == null) { + if (attributes.get(GrpcAttributes.ATTR_AUTHORITY_VERIFIER) == null) { + authorityVerificationStatus = Status.UNAVAILABLE.withDescription( + "Authority verifier not found to verify authority"); + command.stream().setNonExistent(); + command.stream().transportReportStatus( + authorityVerificationStatus, RpcProgress.PROCESSED, true, new Metadata()); + promise.setFailure(InternalStatus.asRuntimeExceptionWithoutStacktrace( + authorityVerificationStatus, null)); + return; + } + authorityVerificationStatus = attributes.get(GrpcAttributes.ATTR_AUTHORITY_VERIFIER) + .verifyAuthority(authorityHeader.toString()); + 
peerVerificationResults.put(authorityHeader.toString(), authorityVerificationStatus); + if (!authorityVerificationStatus.isOk() && !enablePerRpcAuthorityCheck) { + logger.log(Level.WARNING, String.format("%s.%s", + authorityVerificationStatus.getDescription(), + enablePerRpcAuthorityCheck + ? "" : " This will be an error in the future."), + InternalStatus.asRuntimeExceptionWithoutStacktrace( + authorityVerificationStatus, null)); + } + } + if (!authorityVerificationStatus.isOk()) { + if (enablePerRpcAuthorityCheck) { + command.stream().setNonExistent(); + command.stream().transportReportStatus( + authorityVerificationStatus, RpcProgress.PROCESSED, true, new Metadata()); + promise.setFailure(InternalStatus.asRuntimeExceptionWithoutStacktrace( + authorityVerificationStatus, null)); + return; + } + } + } // Get the stream ID for the new stream. int streamId; try { @@ -564,7 +688,7 @@ private void createStream(CreateStreamCommand command, ChannelPromise promise) if (!connection().goAwaySent()) { logger.fine("Stream IDs have been exhausted for this connection. " + "Initiating graceful shutdown of the connection."); - lifecycleManager.notifyShutdown(e.getStatus()); + lifecycleManager.notifyShutdown(e.getStatus(), SimpleDisconnectError.UNKNOWN); close(ctx(), ctx().newPromise()); } return; @@ -635,14 +759,19 @@ public void operationComplete(ChannelFuture future) throws Exception { // Attach the client stream to the HTTP/2 stream object as user data. stream.setHttp2Stream(http2Stream); + promise.setSuccess(); + } else { + // Otherwise, the stream has been cancelled and Netty is sending a + // RST_STREAM frame which causes it to purge pending writes from the + // flow-controller and delete the http2Stream. The stream listener has already + // been notified of cancellation so there is nothing to do. + // + // This process has been observed to fail in some circumstances, leaving listeners + // unanswered. 
Ensure that some exception has been delivered consistent with the + // implied RST_STREAM result above. + Status status = Status.INTERNAL.withDescription("unknown stream for connection"); + promise.setFailure(status.asRuntimeException()); } - // Otherwise, the stream has been cancelled and Netty is sending a - // RST_STREAM frame which causes it to purge pending writes from the - // flow-controller and delete the http2Stream. The stream listener has already - // been notified of cancellation so there is nothing to do. - - // Just forward on the success status to the original promise. - promise.setSuccess(); } else { Throwable cause = future.cause(); if (cause instanceof StreamBufferingEncoder.Http2GoAwayException) { @@ -665,6 +794,19 @@ public void operationComplete(ChannelFuture future) throws Exception { } } }); + // When the HEADERS are not buffered because of MAX_CONCURRENT_STREAMS in + // StreamBufferingEncoder, the stream is created immediately even if the bytes of the HEADERS + // are delayed because the OS may have too much buffered and isn't accepting the write. The + // write promise is also delayed until flush(). However, we need to associate the netty stream + // with the transport state so that goingAway() and forcefulClose() and able to notify the + // stream of failures. + // + // This leaves a hole when MAX_CONCURRENT_STREAMS is reached, as http2Stream will be null, but + // it is better than nothing. 
+ Http2Stream http2Stream = connection().stream(streamId); + if (http2Stream != null) { + http2Stream.setProperty(streamKey, stream); + } } /** @@ -750,19 +892,21 @@ private void sendPingFrameTraced(ChannelHandlerContext ctx, SendPingCommand msg, public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { transportTracer.reportKeepAliveSent(); - } else { - Throwable cause = future.cause(); - if (cause instanceof ClosedChannelException) { - cause = lifecycleManager.getShutdownThrowable(); - if (cause == null) { - cause = Status.UNKNOWN.withDescription("Ping failed but for unknown reason.") - .withCause(future.cause()).asException(); - } - } - finalPing.failed(cause); - if (ping == finalPing) { - ping = null; + return; + } + Throwable cause = future.cause(); + Status status = lifecycleManager.getShutdownStatus(); + if (cause instanceof ClosedChannelException) { + if (status == null) { + status = Status.UNKNOWN.withDescription("Ping failed but for unknown reason.") + .withCause(future.cause()); } + } else { + status = Utils.statusFromThrowable(cause); + } + finalPing.failed(status); + if (ping == finalPing) { + ping = null; } } }); @@ -770,7 +914,7 @@ public void operationComplete(ChannelFuture future) throws Exception { private void gracefulClose(ChannelHandlerContext ctx, GracefulCloseCommand msg, ChannelPromise promise) throws Exception { - lifecycleManager.notifyShutdown(msg.getStatus()); + lifecycleManager.notifyShutdown(msg.getStatus(), SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // Explicitly flush to create any buffered streams before sending GOAWAY. 
// TODO(ejona): determine if the need to flush is a bug in Netty flush(ctx); @@ -806,13 +950,15 @@ public boolean visit(Http2Stream stream) throws Http2Exception { private void goingAway(long errorCode, byte[] debugData) { Status finalStatus = statusFromH2Error( Status.Code.UNAVAILABLE, "GOAWAY shut down transport", errorCode, debugData); - lifecycleManager.notifyGracefulShutdown(finalStatus); + DisconnectError disconnectError = new GoAwayDisconnectError( + GrpcUtil.Http2Error.forCode(errorCode)); + lifecycleManager.notifyGracefulShutdown(finalStatus, disconnectError); abruptGoAwayStatus = statusFromH2Error( Status.Code.UNAVAILABLE, "Abrupt GOAWAY closed unsent stream", errorCode, debugData); // While this _should_ be UNAVAILABLE, Netty uses the wrong stream id in the GOAWAY when it // fails streams due to HPACK failures (e.g., header list too large). To be more conservative, // we assume any sent streams may be related to the GOAWAY. This should rarely impact users - // since the main time servers should use abrupt GOAWAYs is if there is a protocol error, and if + // since the main time servers should use abrupt GOAWAYs if there is a protocol error, and if // there wasn't a protocol error the error code was probably NO_ERROR which is mapped to // UNAVAILABLE. https://github.com/netty/netty/issues/10670 final Status abruptGoAwayStatusConservative = statusFromH2Error( @@ -827,7 +973,7 @@ private void goingAway(long errorCode, byte[] debugData) { // This can cause reentrancy, but should be minor since it is normal to handle writes in // response to a read. Also, the call stack is rather shallow at this point clientWriteQueue.drainNow(); - if (lifecycleManager.notifyShutdown(finalStatus)) { + if (lifecycleManager.notifyShutdown(finalStatus, disconnectError)) { // This is for the only RPCs that are actually covered by the GOAWAY error code. All other // RPCs were not observed by the remote and so should be UNAVAILABLE. 
channelInactiveReason = statusFromH2Error( @@ -861,9 +1007,9 @@ public boolean visit(Http2Stream stream) throws Http2Exception { } } - private void cancelPing(Throwable t) { + private void cancelPing(Status s) { if (ping != null) { - ping.failed(t); + ping.failed(s); ping = null; } } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java index a82470cacb4..6585df42df3 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java @@ -34,14 +34,17 @@ import io.grpc.InternalLogId; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MetricRecorder; import io.grpc.Status; import io.grpc.internal.ClientStream; import io.grpc.internal.ConnectionClientTransport; +import io.grpc.internal.DisconnectError; import io.grpc.internal.FailingClientStream; import io.grpc.internal.GrpcUtil; import io.grpc.internal.Http2Ping; import io.grpc.internal.KeepAliveManager; import io.grpc.internal.KeepAliveManager.ClientKeepAlivePinger; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; @@ -68,7 +71,8 @@ /** * A Netty-based {@link ConnectionClientTransport} implementation. 
*/ -class NettyClientTransport implements ConnectionClientTransport { +class NettyClientTransport implements ConnectionClientTransport, + ClientKeepAlivePinger.TransportWithDisconnectReason { private final InternalLogId logId; private final Map, ?> channelOptions; @@ -83,6 +87,7 @@ class NettyClientTransport implements ConnectionClientTransport { private final int flowControlWindow; private final int maxMessageSize; private final int maxHeaderListSize; + private final int softLimitHeaderListSize; private KeepAliveManager keepAliveManager; private final long keepAliveTimeNanos; private final long keepAliveTimeoutNanos; @@ -104,17 +109,33 @@ class NettyClientTransport implements ConnectionClientTransport { private final ChannelLogger channelLogger; private final boolean useGetForSafeMethods; private final Ticker ticker; + private final MetricRecorder metricRecorder; + NettyClientTransport( - SocketAddress address, ChannelFactory channelFactory, - Map, ?> channelOptions, EventLoopGroup group, - ProtocolNegotiator negotiator, boolean autoFlowControl, int flowControlWindow, - int maxMessageSize, int maxHeaderListSize, - long keepAliveTimeNanos, long keepAliveTimeoutNanos, - boolean keepAliveWithoutCalls, String authority, @Nullable String userAgent, - Runnable tooManyPingsRunnable, TransportTracer transportTracer, Attributes eagAttributes, - LocalSocketPicker localSocketPicker, ChannelLogger channelLogger, - boolean useGetForSafeMethods, Ticker ticker) { + SocketAddress address, + ChannelFactory channelFactory, + Map, ?> channelOptions, + EventLoopGroup group, + ProtocolNegotiator negotiator, + boolean autoFlowControl, + int flowControlWindow, + int maxMessageSize, + int maxHeaderListSize, + int softLimitHeaderListSize, + long keepAliveTimeNanos, + long keepAliveTimeoutNanos, + boolean keepAliveWithoutCalls, + String authority, + @Nullable String userAgent, + Runnable tooManyPingsRunnable, + TransportTracer transportTracer, + Attributes eagAttributes, + 
LocalSocketPicker localSocketPicker, + ChannelLogger channelLogger, + boolean useGetForSafeMethods, + MetricRecorder metricRecorder, + Ticker ticker) { this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator"); this.negotiationScheme = this.negotiator.scheme(); @@ -126,6 +147,7 @@ class NettyClientTransport implements ConnectionClientTransport { this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; + this.softLimitHeaderListSize = softLimitHeaderListSize; this.keepAliveTimeNanos = keepAliveTimeNanos; this.keepAliveTimeoutNanos = keepAliveTimeoutNanos; this.keepAliveWithoutCalls = keepAliveWithoutCalls; @@ -140,6 +162,7 @@ class NettyClientTransport implements ConnectionClientTransport { this.logId = InternalLogId.allocate(getClass(), remoteAddress.toString()); this.channelLogger = Preconditions.checkNotNull(channelLogger, "channelLogger"); this.useGetForSafeMethods = useGetForSafeMethods; + this.metricRecorder = metricRecorder; this.ticker = Preconditions.checkNotNull(ticker, "ticker"); } @@ -149,7 +172,7 @@ public void ping(final PingCallback callback, final Executor executor) { executor.execute(new Runnable() { @Override public void run() { - callback.onFailure(statusExplainingWhyTheChannelIsNull.asException()); + callback.onFailure(statusExplainingWhyTheChannelIsNull); } }); return; @@ -161,7 +184,7 @@ public void run() { public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { Status s = statusFromFailedFuture(future); - Http2Ping.notifyFailed(callback, executor, s.asException()); + Http2Ping.notifyFailed(callback, executor, s); } } }; @@ -215,23 +238,25 @@ public Runnable start(Listener transportListener) { EventLoop eventLoop = group.next(); if (keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED) { keepAliveManager = new KeepAliveManager( - new ClientKeepAlivePinger(this), eventLoop, keepAliveTimeNanos, keepAliveTimeoutNanos, - 
keepAliveWithoutCalls); + new ClientKeepAlivePinger(this), eventLoop, keepAliveTimeNanos, + keepAliveTimeoutNanos, keepAliveWithoutCalls); } handler = NettyClientHandler.newHandler( - lifecycleManager, - keepAliveManager, - autoFlowControl, - flowControlWindow, - maxHeaderListSize, - GrpcUtil.STOPWATCH_SUPPLIER, - tooManyPingsRunnable, - transportTracer, - eagAttributes, - authorityString, - channelLogger, - ticker); + lifecycleManager, + keepAliveManager, + autoFlowControl, + flowControlWindow, + maxHeaderListSize, + softLimitHeaderListSize, + GrpcUtil.STOPWATCH_SUPPLIER, + tooManyPingsRunnable, + transportTracer, + eagAttributes, + authorityString, + channelLogger, + ticker, + metricRecorder); ChannelHandler negotiationHandler = negotiator.newHandler(handler); @@ -241,13 +266,6 @@ public Runnable start(Listener transportListener) { b.channelFactory(channelFactory); // For non-socket based channel, the option will be ignored. b.option(SO_KEEPALIVE, true); - // For non-epoll based channel, the option will be ignored. - if (keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED) { - ChannelOption tcpUserTimeout = Utils.maybeGetTcpUserTimeoutOption(); - if (tcpUserTimeout != null) { - b.option(tcpUserTimeout, (int) TimeUnit.NANOSECONDS.toMillis(keepAliveTimeoutNanos)); - } - } for (Map.Entry, ?> entry : channelOptions.entrySet()) { // Every entry in the map is obtained from // NettyChannelBuilder#withOption(ChannelOption option, T value) @@ -281,11 +299,26 @@ public void run() { // could use GlobalEventExecutor (which is what regFuture would use for notifying // listeners in this case), but avoiding on-demand thread creation in an error case seems // a good idea and is probably clearer threading. 
- lifecycleManager.notifyTerminated(statusExplainingWhyTheChannelIsNull); + lifecycleManager.notifyTerminated(statusExplainingWhyTheChannelIsNull, + SimpleDisconnectError.UNKNOWN); } }; } channel = regFuture.channel(); + // For non-epoll based channel, the option will be ignored. + try { + if (keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED + && Class.forName("io.netty.channel.epoll.AbstractEpollChannel").isInstance(channel)) { + ChannelOption tcpUserTimeout = Utils.maybeGetTcpUserTimeoutOption(); + if (tcpUserTimeout != null) { + int tcpUserTimeoutMs = (int) TimeUnit.NANOSECONDS.toMillis(keepAliveTimeoutNanos); + channel.config().setOption(tcpUserTimeout, tcpUserTimeoutMs); + } + } + } catch (ClassNotFoundException ignored) { + // JVM did not load AbstractEpollChannel, so the current channel will not be of epoll type, + // so there is no need to set TCP_USER_TIMEOUT + } // Start the write queue as soon as the channel is constructed handler.startWriteQueue(channel); // This write will have no effect, yet it will only complete once the negotiationHandler @@ -299,7 +332,8 @@ public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { // Need to notify of this failure, because NettyClientHandler may not have been added to // the pipeline before the error occurred. - lifecycleManager.notifyTerminated(Utils.statusFromThrowable(future.cause())); + lifecycleManager.notifyTerminated(Utils.statusFromThrowable(future.cause()), + SimpleDisconnectError.UNKNOWN); } } }); @@ -333,12 +367,17 @@ public void shutdown(Status reason) { @Override public void shutdownNow(final Status reason) { + shutdownNow(reason, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + } + + @Override + public void shutdownNow(final Status reason, DisconnectError disconnectError) { // Notifying of termination is automatically done when the channel closes. 
if (channel != null && channel.isOpen()) { handler.getWriteQueue().enqueue(new Runnable() { @Override public void run() { - lifecycleManager.notifyShutdown(reason); + lifecycleManager.notifyShutdown(reason, disconnectError); channel.write(new ForcefulCloseCommand(reason)); } }, true); diff --git a/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java b/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java index 7e180544de4..af5ec8d8bad 100644 --- a/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java +++ b/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java @@ -60,11 +60,6 @@ public void readBytes(byte[] dest, int index, int length) { buffer.readBytes(dest, index, length); } - @Override - public void readBytes(ByteBuffer dest) { - buffer.readBytes(dest); - } - @Override public void readBytes(OutputStream dest, int length) { try { diff --git a/netty/src/main/java/io/grpc/netty/NettyServer.java b/netty/src/main/java/io/grpc/netty/NettyServer.java index 2960604e5b5..2bb6b2c5921 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServer.java +++ b/netty/src/main/java/io/grpc/netty/NettyServer.java @@ -31,6 +31,7 @@ import io.grpc.InternalInstrumented; import io.grpc.InternalLogId; import io.grpc.InternalWithLogId; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.internal.InternalServer; import io.grpc.internal.ObjectPool; @@ -92,6 +93,8 @@ class NettyServer implements InternalServer, InternalWithLogId { private final int flowControlWindow; private final int maxMessageSize; private final int maxHeaderListSize; + private final int softLimitHeaderListSize; + private MetricRecorder metricRecorder; private final long keepAliveTimeInNanos; private final long keepAliveTimeoutInNanos; private final long maxConnectionIdleInNanos; @@ -123,15 +126,22 @@ class NettyServer implements InternalServer, InternalWithLogId { ProtocolNegotiator protocolNegotiator, List streamTracerFactories, TransportTracer.Factory 
transportTracerFactory, - int maxStreamsPerConnection, boolean autoFlowControl, int flowControlWindow, - int maxMessageSize, int maxHeaderListSize, - long keepAliveTimeInNanos, long keepAliveTimeoutInNanos, + int maxStreamsPerConnection, + boolean autoFlowControl, + int flowControlWindow, + int maxMessageSize, + int maxHeaderListSize, + int softLimitHeaderListSize, + long keepAliveTimeInNanos, + long keepAliveTimeoutInNanos, long maxConnectionIdleInNanos, long maxConnectionAgeInNanos, long maxConnectionAgeGraceInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, int maxRstCount, long maxRstPeriodNanos, - Attributes eagAttributes, InternalChannelz channelz) { + Attributes eagAttributes, InternalChannelz channelz, + MetricRecorder metricRecorder) { this.addresses = checkNotNull(addresses, "addresses"); + this.metricRecorder = metricRecorder; this.channelFactory = checkNotNull(channelFactory, "channelFactory"); checkNotNull(channelOptions, "channelOptions"); this.channelOptions = new HashMap, Object>(channelOptions); @@ -152,6 +162,7 @@ class NettyServer implements InternalServer, InternalWithLogId { this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; + this.softLimitHeaderListSize = softLimitHeaderListSize; this.keepAliveTimeInNanos = keepAliveTimeInNanos; this.keepAliveTimeoutInNanos = keepAliveTimeoutInNanos; this.maxConnectionIdleInNanos = maxConnectionIdleInNanos; @@ -167,6 +178,7 @@ class NettyServer implements InternalServer, InternalWithLogId { String.valueOf(addresses)); } + @Override public SocketAddress getListenSocketAddress() { Iterator it = channelGroup.iterator(); @@ -243,28 +255,30 @@ public void initChannel(Channel ch) { (long) ((.9D + Math.random() * .2D) * maxConnectionAgeInNanos); } - NettyServerTransport transport = - new NettyServerTransport( - ch, - channelDone, - protocolNegotiator, - streamTracerFactories, - transportTracerFactory.create(), - 
maxStreamsPerConnection, - autoFlowControl, - flowControlWindow, - maxMessageSize, - maxHeaderListSize, - keepAliveTimeInNanos, - keepAliveTimeoutInNanos, - maxConnectionIdleInNanos, - maxConnectionAgeInNanos, - maxConnectionAgeGraceInNanos, - permitKeepAliveWithoutCalls, - permitKeepAliveTimeInNanos, - maxRstCount, - maxRstPeriodNanos, - eagAttributes); + NettyServerTransport transport = + new NettyServerTransport( + ch, + channelDone, + protocolNegotiator, + streamTracerFactories, + transportTracerFactory.create(), + maxStreamsPerConnection, + autoFlowControl, + flowControlWindow, + maxMessageSize, + maxHeaderListSize, + softLimitHeaderListSize, + keepAliveTimeInNanos, + keepAliveTimeoutInNanos, + maxConnectionIdleInNanos, + maxConnectionAgeInNanos, + maxConnectionAgeGraceInNanos, + permitKeepAliveWithoutCalls, + permitKeepAliveTimeInNanos, + maxRstCount, + maxRstPeriodNanos, + eagAttributes, + metricRecorder); ServerTransportListener transportListener; // This is to order callbacks on the listener, not to guard access to channel. 
synchronized (NettyServer.this) { diff --git a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java index 3b82b193f61..3c9d2bbe184 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java @@ -32,6 +32,7 @@ import io.grpc.ExperimentalApi; import io.grpc.ForwardingServerBuilder; import io.grpc.Internal; +import io.grpc.MetricRecorder; import io.grpc.ServerBuilder; import io.grpc.ServerCredentials; import io.grpc.ServerStreamTracer; @@ -105,6 +106,7 @@ public final class NettyServerBuilder extends ForwardingServerBuilder streamTracerFactories) { - return buildTransportServers(streamTracerFactories); + List streamTracerFactories, + MetricRecorder metricRecorder) { + return buildTransportServers(streamTracerFactories, metricRecorder); } } @@ -492,6 +495,39 @@ public NettyServerBuilder maxHeaderListSize(int maxHeaderListSize) { public NettyServerBuilder maxInboundMetadataSize(int bytes) { checkArgument(bytes > 0, "maxInboundMetadataSize must be positive: %s", bytes); this.maxHeaderListSize = bytes; + // Clear the soft limit setting, by setting soft limit to maxInboundMetadataSize. The + // maxInboundMetadataSize will take precedence over soft limit check. + this.softLimitHeaderListSize = bytes; + return this; + } + + /** + * Sets the size of metadata that clients are advised to not exceed. When a metadata with size + * larger than the soft limit is encountered there will be a probability the RPC will fail. The + * chance of failing increases as the metadata size approaches the hard limit. + * {@code Integer.MAX_VALUE} disables the enforcement. The default is implementation-dependent, + * but is not generally less than 8 KiB and may be unlimited. + * + *

This is cumulative size of the metadata. The precise calculation is + * implementation-dependent, but implementations are encouraged to follow the calculation used + * for + * HTTP/2's + * SETTINGS_MAX_HEADER_LIST_SIZE. It sums the bytes from each entry's key and value, plus 32 + * bytes of overhead per entry. + * + * @param soft the soft size limit of received metadata + * @param max the hard size limit of received metadata + * @return this + * @throws IllegalArgumentException if soft and/or max is non-positive, or max smaller than soft + * @since 1.68.0 + */ + @CanIgnoreReturnValue + public NettyServerBuilder maxInboundMetadataSize(int soft, int max) { + checkArgument(soft > 0, "softLimitHeaderListSize must be positive: %s", soft); + checkArgument(max > soft, + "maxInboundMetadataSize: %s must be greater than softLimitHeaderListSize: %s", max, soft); + this.softLimitHeaderListSize = soft; + this.maxHeaderListSize = max; return this; } @@ -669,22 +705,44 @@ void eagAttributes(Attributes eagAttributes) { this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); } + @VisibleForTesting NettyServer buildTransportServers( - List streamTracerFactories) { + List streamTracerFactories, + MetricRecorder metricRecorder) { assertEventLoopsAndChannelType(); ProtocolNegotiator negotiator = protocolNegotiatorFactory.newNegotiator( this.serverImplBuilder.getExecutorPool()); return new NettyServer( - listenAddresses, channelFactory, channelOptions, childChannelOptions, - bossEventLoopGroupPool, workerEventLoopGroupPool, forceHeapBuffer, negotiator, - streamTracerFactories, transportTracerFactory, maxConcurrentCallsPerConnection, - autoFlowControl, flowControlWindow, maxMessageSize, maxHeaderListSize, - keepAliveTimeInNanos, keepAliveTimeoutInNanos, - maxConnectionIdleInNanos, maxConnectionAgeInNanos, - maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, - maxRstCount, maxRstPeriodNanos, eagAttributes, 
this.serverImplBuilder.getChannelz()); + listenAddresses, + channelFactory, + channelOptions, + childChannelOptions, + bossEventLoopGroupPool, + workerEventLoopGroupPool, + forceHeapBuffer, + negotiator, + streamTracerFactories, + transportTracerFactory, + maxConcurrentCallsPerConnection, + autoFlowControl, + flowControlWindow, + maxMessageSize, + maxHeaderListSize, + softLimitHeaderListSize, + keepAliveTimeInNanos, + keepAliveTimeoutInNanos, + maxConnectionIdleInNanos, + maxConnectionAgeInNanos, + maxConnectionAgeGraceInNanos, + permitKeepAliveWithoutCalls, + permitKeepAliveTimeInNanos, + maxRstCount, + maxRstPeriodNanos, + eagAttributes, + this.serverImplBuilder.getChannelz(), + metricRecorder); } @VisibleForTesting diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index a6e855a199d..79715ca2996 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -42,6 +42,7 @@ import io.grpc.InternalMetadata; import io.grpc.InternalStatus; import io.grpc.Metadata; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.internal.GrpcUtil; @@ -60,6 +61,8 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoder; import io.netty.handler.codec.http2.DecoratingHttp2FrameWriter; import io.netty.handler.codec.http2.DefaultHttp2Connection; import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; @@ -67,8 +70,10 @@ import io.netty.handler.codec.http2.DefaultHttp2FrameReader; import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.DefaultHttp2HeadersEncoder; import 
io.netty.handler.codec.http2.DefaultHttp2LocalFlowController; import io.netty.handler.codec.http2.DefaultHttp2RemoteFlowController; +import io.netty.handler.codec.http2.EmptyHttp2Headers; import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2ConnectionAdapter; import io.netty.handler.codec.http2.Http2ConnectionDecoder; @@ -82,12 +87,14 @@ import io.netty.handler.codec.http2.Http2FrameWriter; import io.netty.handler.codec.http2.Http2Headers; import io.netty.handler.codec.http2.Http2HeadersDecoder; +import io.netty.handler.codec.http2.Http2HeadersEncoder; import io.netty.handler.codec.http2.Http2InboundFrameLogger; +import io.netty.handler.codec.http2.Http2LifecycleManager; import io.netty.handler.codec.http2.Http2OutboundFrameLogger; import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.codec.http2.Http2Stream; import io.netty.handler.codec.http2.Http2StreamVisitor; -import io.netty.handler.codec.http2.WeightedFairQueueByteDistributor; +import io.netty.handler.codec.http2.UniformStreamByteDistributor; import io.netty.handler.logging.LogLevel; import io.netty.util.AsciiString; import io.netty.util.ReferenceCountUtil; @@ -121,17 +128,16 @@ class NettyServerHandler extends AbstractNettyHandler { private final Http2Connection.PropertyKey streamKey; private final ServerTransportListener transportListener; private final int maxMessageSize; + private final TcpMetrics tcpMetrics; private final long keepAliveTimeInNanos; private final long keepAliveTimeoutInNanos; private final long maxConnectionAgeInNanos; private final long maxConnectionAgeGraceInNanos; - private final int maxRstCount; - private final long maxRstPeriodNanos; + private final RstStreamCounter rstStreamCounter; private final List streamTracerFactories; private final TransportTracer transportTracer; private final KeepAliveEnforcer keepAliveEnforcer; private final Attributes eagAttributes; - private final Ticker ticker; /** Incomplete attributes 
produced by negotiator. */ private Attributes negotiationAttributes; private InternalChannelz.Security securityInfo; @@ -149,9 +155,6 @@ class NettyServerHandler extends AbstractNettyHandler { private ScheduledFuture maxConnectionAgeMonitor; @CheckForNull private GracefulShutdown gracefulShutdown; - private int rstCount; - private long lastRstNanoTime; - static NettyServerHandler newHandler( ServerTransportListener transportListener, @@ -162,6 +165,7 @@ static NettyServerHandler newHandler( boolean autoFlowControl, int flowControlWindow, int maxHeaderListSize, + int softLimitHeaderListSize, int maxMessageSize, long keepAliveTimeInNanos, long keepAliveTimeoutInNanos, @@ -172,15 +176,18 @@ static NettyServerHandler newHandler( long permitKeepAliveTimeInNanos, int maxRstCount, long maxRstPeriodNanos, - Attributes eagAttributes) { + Attributes eagAttributes, + MetricRecorder metricRecorder) { Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive: %s", maxHeaderListSize); Http2FrameLogger frameLogger = new Http2FrameLogger(LogLevel.DEBUG, NettyServerHandler.class); Http2HeadersDecoder headersDecoder = new GrpcHttp2ServerHeadersDecoder(maxHeaderListSize); Http2FrameReader frameReader = new Http2InboundFrameLogger( new DefaultHttp2FrameReader(headersDecoder), frameLogger); + Http2HeadersEncoder encoder = new DefaultHttp2HeadersEncoder( + Http2HeadersEncoder.NEVER_SENSITIVE, false, 16, Integer.MAX_VALUE); Http2FrameWriter frameWriter = - new Http2OutboundFrameLogger(new DefaultHttp2FrameWriter(), frameLogger); + new Http2OutboundFrameLogger(new DefaultHttp2FrameWriter(encoder), frameLogger); return newHandler( channelUnused, frameReader, @@ -192,6 +199,7 @@ static NettyServerHandler newHandler( autoFlowControl, flowControlWindow, maxHeaderListSize, + softLimitHeaderListSize, maxMessageSize, keepAliveTimeInNanos, keepAliveTimeoutInNanos, @@ -203,7 +211,8 @@ static NettyServerHandler newHandler( maxRstCount, maxRstPeriodNanos, eagAttributes, 
- Ticker.systemTicker()); + Ticker.systemTicker(), + metricRecorder); } static NettyServerHandler newHandler( @@ -217,6 +226,7 @@ static NettyServerHandler newHandler( boolean autoFlowControl, int flowControlWindow, int maxHeaderListSize, + int softLimitHeaderListSize, int maxMessageSize, long keepAliveTimeInNanos, long keepAliveTimeoutInNanos, @@ -228,24 +238,34 @@ static NettyServerHandler newHandler( int maxRstCount, long maxRstPeriodNanos, Attributes eagAttributes, - Ticker ticker) { + Ticker ticker, + MetricRecorder metricRecorder) { Preconditions.checkArgument(maxStreams > 0, "maxStreams must be positive: %s", maxStreams); Preconditions.checkArgument(flowControlWindow > 0, "flowControlWindow must be positive: %s", flowControlWindow); Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive: %s", maxHeaderListSize); + Preconditions.checkArgument( + softLimitHeaderListSize > 0, "softLimitHeaderListSize must be positive: %s", + softLimitHeaderListSize); Preconditions.checkArgument(maxMessageSize > 0, "maxMessageSize must be positive: %s", maxMessageSize); final Http2Connection connection = new DefaultHttp2Connection(true); - WeightedFairQueueByteDistributor dist = new WeightedFairQueueByteDistributor(connection); - dist.allocationQuantum(16 * 1024); // Make benchmarks fast again. + UniformStreamByteDistributor dist = new UniformStreamByteDistributor(connection); + dist.minAllocationChunk(MIN_ALLOCATED_CHUNK); // Increased for benchmarks performance. 
DefaultHttp2RemoteFlowController controller = new DefaultHttp2RemoteFlowController(connection, dist); connection.remote().flowController(controller); final KeepAliveEnforcer keepAliveEnforcer = new KeepAliveEnforcer( permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, TimeUnit.NANOSECONDS); + if (ticker == null) { + ticker = Ticker.systemTicker(); + } + + RstStreamCounter rstStreamCounter + = new RstStreamCounter(maxRstCount, maxRstPeriodNanos, ticker); // Create the local flow controller configured to auto-refill the connection window. connection.local().flowController( new DefaultHttp2LocalFlowController(connection, DEFAULT_WINDOW_UPDATE_RATIO, true)); @@ -253,6 +273,7 @@ static NettyServerHandler newHandler( Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter); encoder = new Http2ControlFrameLimitEncoder(encoder, 10000); + encoder = new Http2RstCounterEncoder(encoder, rstStreamCounter); Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader); @@ -261,10 +282,6 @@ static NettyServerHandler newHandler( settings.maxConcurrentStreams(maxStreams); settings.maxHeaderListSize(maxHeaderListSize); - if (ticker == null) { - ticker = Ticker.systemTicker(); - } - return new NettyServerHandler( channelUnused, connection, @@ -273,14 +290,17 @@ static NettyServerHandler newHandler( transportTracer, decoder, encoder, settings, maxMessageSize, - keepAliveTimeInNanos, keepAliveTimeoutInNanos, + maxHeaderListSize, + softLimitHeaderListSize, + keepAliveTimeInNanos, + keepAliveTimeoutInNanos, maxConnectionIdleInNanos, maxConnectionAgeInNanos, maxConnectionAgeGraceInNanos, keepAliveEnforcer, autoFlowControl, - maxRstCount, - maxRstPeriodNanos, - eagAttributes, ticker); + rstStreamCounter, + eagAttributes, ticker, + metricRecorder); } private NettyServerHandler( @@ -293,6 +313,8 @@ private NettyServerHandler( Http2ConnectionEncoder encoder, Http2Settings settings, int maxMessageSize, + int 
maxHeaderListSize, + int softLimitHeaderListSize, long keepAliveTimeInNanos, long keepAliveTimeoutInNanos, long maxConnectionIdleInNanos, @@ -300,12 +322,21 @@ private NettyServerHandler( long maxConnectionAgeGraceInNanos, final KeepAliveEnforcer keepAliveEnforcer, boolean autoFlowControl, - int maxRstCount, - long maxRstPeriodNanos, + RstStreamCounter rstStreamCounter, Attributes eagAttributes, - Ticker ticker) { - super(channelUnused, decoder, encoder, settings, new ServerChannelLogger(), - autoFlowControl, null, ticker); + Ticker ticker, + MetricRecorder metricRecorder) { + super( + channelUnused, + decoder, + encoder, + settings, + new ServerChannelLogger(), + autoFlowControl, + null, + ticker, + maxHeaderListSize, + softLimitHeaderListSize); final MaxConnectionIdleManager maxConnectionIdleManager; if (maxConnectionIdleInNanos == MAX_CONNECTION_IDLE_NANOS_DISABLED) { @@ -338,18 +369,16 @@ public void onStreamClosed(Http2Stream stream) { checkArgument(maxMessageSize >= 0, "maxMessageSize must be non-negative: %s", maxMessageSize); this.maxMessageSize = maxMessageSize; + this.tcpMetrics = new TcpMetrics(metricRecorder); this.keepAliveTimeInNanos = keepAliveTimeInNanos; this.keepAliveTimeoutInNanos = keepAliveTimeoutInNanos; this.maxConnectionIdleManager = maxConnectionIdleManager; this.maxConnectionAgeInNanos = maxConnectionAgeInNanos; this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos; this.keepAliveEnforcer = checkNotNull(keepAliveEnforcer, "keepAliveEnforcer"); - this.maxRstCount = maxRstCount; - this.maxRstPeriodNanos = maxRstPeriodNanos; + this.rstStreamCounter = rstStreamCounter; this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); - this.ticker = checkNotNull(ticker, "ticker"); - this.lastRstNanoTime = ticker.read(); streamKey = encoder.connection().newKey(); this.transportListener = checkNotNull(transportListener, "transportListener"); this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories"); @@ 
-465,8 +494,20 @@ private void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers } if (!HTTP_METHOD.contentEquals(headers.method())) { + Http2Headers extraHeaders = new DefaultHttp2Headers(); + extraHeaders.add(HttpHeaderNames.ALLOW, HTTP_METHOD); respondWithHttpError(ctx, streamId, 405, Status.Code.INTERNAL, - String.format("Method '%s' is not supported", headers.method())); + String.format("Method '%s' is not supported", headers.method()), extraHeaders); + return; + } + + int h2HeadersSize = Utils.getH2HeadersSize(headers); + if (Utils.shouldRejectOnMetadataSizeSoftLimitExceeded( + h2HeadersSize, softLimitHeaderListSize, maxHeaderListSize)) { + respondWithHttpError(ctx, streamId, 431, Status.Code.RESOURCE_EXHAUSTED, String.format( + "Client Headers of size %d exceeded Metadata size soft limit: %d", + h2HeadersSize, + softLimitHeaderListSize)); return; } @@ -546,24 +587,9 @@ private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfSt } private void onRstStreamRead(int streamId, long errorCode) throws Http2Exception { - if (maxRstCount > 0) { - long now = ticker.read(); - if (now - lastRstNanoTime > maxRstPeriodNanos) { - lastRstNanoTime = now; - rstCount = 1; - } else { - rstCount++; - if (rstCount > maxRstCount) { - throw new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") { - @SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses - @Override - public Throwable fillInStackTrace() { - // Avoid the CPU cycles, since the resets may be a CPU consumption attack - return this; - } - }; - } - } + Http2Exception tooManyRstStream = rstStreamCounter.countRstStream(); + if (tooManyRstStream != null) { + throw tooManyRstStream; } try { @@ -643,9 +669,16 @@ void setKeepAliveManagerForTest(KeepAliveManager keepAliveManager) { /** * Handler for the Channel shutting down. 
*/ + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + tcpMetrics.channelActive(ctx.channel()); + super.channelActive(ctx); + } + @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { try { + tcpMetrics.channelInactive(ctx.channel()); if (keepAliveManager != null) { keepAliveManager.onTransportTermination(); } @@ -859,6 +892,12 @@ public boolean visit(Http2Stream stream) throws Http2Exception { private void respondWithHttpError( ChannelHandlerContext ctx, int streamId, int code, Status.Code statusCode, String msg) { + respondWithHttpError(ctx, streamId, code, statusCode, msg, EmptyHttp2Headers.INSTANCE); + } + + private void respondWithHttpError( + ChannelHandlerContext ctx, int streamId, int code, Status.Code statusCode, String msg, + Http2Headers extraHeaders) { Metadata metadata = new Metadata(); metadata.put(InternalStatus.CODE_KEY, statusCode.toStatus()); metadata.put(InternalStatus.MESSAGE_KEY, msg); @@ -870,6 +909,7 @@ private void respondWithHttpError( for (int i = 0; i < serialized.length; i += 2) { headers.add(new AsciiString(serialized[i], false), new AsciiString(serialized[i + 1], false)); } + headers.add(extraHeaders); encoder().writeHeaders(ctx, streamId, headers, 0, false, ctx.newPromise()); ByteBuf msgBuf = ByteBufUtil.writeUtf8(ctx.alloc(), msg); encoder().writeData(ctx, streamId, msgBuf, 0, true, ctx.newPromise()); @@ -1151,6 +1191,81 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2 } } + private static final class Http2RstCounterEncoder extends DecoratingHttp2ConnectionEncoder { + private final RstStreamCounter rstStreamCounter; + private Http2LifecycleManager lifecycleManager; + + Http2RstCounterEncoder(Http2ConnectionEncoder encoder, RstStreamCounter rstStreamCounter) { + super(encoder); + this.rstStreamCounter = rstStreamCounter; + } + + @Override + public void lifecycleManager(Http2LifecycleManager lifecycleManager) { + this.lifecycleManager = 
lifecycleManager; + super.lifecycleManager(lifecycleManager); + } + + @Override + public ChannelFuture writeRstStream( + ChannelHandlerContext ctx, int streamId, long errorCode, ChannelPromise promise) { + ChannelFuture future = super.writeRstStream(ctx, streamId, errorCode, promise); + // We want to count "induced" RST_STREAM, where the server sent a reset because of a malformed + // frame. + boolean normalRst + = errorCode == Http2Error.NO_ERROR.code() || errorCode == Http2Error.CANCEL.code(); + if (!normalRst) { + Http2Exception tooManyRstStream = rstStreamCounter.countRstStream(); + if (tooManyRstStream != null) { + lifecycleManager.onError(ctx, true, tooManyRstStream); + ctx.close(); + } + } + return future; + } + } + + private static final class RstStreamCounter { + private final int maxRstCount; + private final long maxRstPeriodNanos; + private final Ticker ticker; + private int rstCount; + private long lastRstNanoTime; + + RstStreamCounter(int maxRstCount, long maxRstPeriodNanos, Ticker ticker) { + checkArgument(maxRstCount >= 0, "maxRstCount must be non-negative: %s", maxRstCount); + this.maxRstCount = maxRstCount; + this.maxRstPeriodNanos = maxRstPeriodNanos; + this.ticker = checkNotNull(ticker, "ticker"); + this.lastRstNanoTime = ticker.read(); + } + + /** Returns non-{@code null} when the connection should be killed by the caller. 
*/ + private Http2Exception countRstStream() { + if (maxRstCount == 0) { + return null; + } + long now = ticker.read(); + if (now - lastRstNanoTime > maxRstPeriodNanos) { + lastRstNanoTime = now; + rstCount = 1; + } else { + rstCount++; + if (rstCount > maxRstCount) { + return new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") { + @SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses + @Override + public Throwable fillInStackTrace() { + // Avoid the CPU cycles, since the resets may be a CPU consumption attack + return this; + } + }; + } + } + return null; + } + } + private static class ServerChannelLogger extends ChannelLogger { private static final Logger log = Logger.getLogger(ChannelLogger.class.getName()); diff --git a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java index 9511927a09f..c0e52b75876 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java @@ -25,6 +25,7 @@ import io.grpc.Attributes; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.internal.ServerTransport; @@ -70,6 +71,7 @@ class NettyServerTransport implements ServerTransport { private final int flowControlWindow; private final int maxMessageSize; private final int maxHeaderListSize; + private final int softLimitHeaderListSize; private final long keepAliveTimeInNanos; private final long keepAliveTimeoutInNanos; private final long maxConnectionIdleInNanos; @@ -80,6 +82,7 @@ class NettyServerTransport implements ServerTransport { private final int maxRstCount; private final long maxRstPeriodNanos; private final Attributes eagAttributes; + private final MetricRecorder metricRecorder; private final List streamTracerFactories; private final TransportTracer transportTracer; 
@@ -94,6 +97,7 @@ class NettyServerTransport implements ServerTransport { int flowControlWindow, int maxMessageSize, int maxHeaderListSize, + int softLimitHeaderListSize, long keepAliveTimeInNanos, long keepAliveTimeoutInNanos, long maxConnectionIdleInNanos, @@ -103,7 +107,8 @@ class NettyServerTransport implements ServerTransport { long permitKeepAliveTimeInNanos, int maxRstCount, long maxRstPeriodNanos, - Attributes eagAttributes) { + Attributes eagAttributes, + MetricRecorder metricRecorder) { this.channel = Preconditions.checkNotNull(channel, "channel"); this.channelUnused = channelUnused; this.protocolNegotiator = Preconditions.checkNotNull(protocolNegotiator, "protocolNegotiator"); @@ -115,6 +120,7 @@ class NettyServerTransport implements ServerTransport { this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; + this.softLimitHeaderListSize = softLimitHeaderListSize; this.keepAliveTimeInNanos = keepAliveTimeInNanos; this.keepAliveTimeoutInNanos = keepAliveTimeoutInNanos; this.maxConnectionIdleInNanos = maxConnectionIdleInNanos; @@ -125,6 +131,7 @@ class NettyServerTransport implements ServerTransport { this.maxRstCount = maxRstCount; this.maxRstPeriodNanos = maxRstPeriodNanos; this.eagAttributes = Preconditions.checkNotNull(eagAttributes, "eagAttributes"); + this.metricRecorder = metricRecorder; SocketAddress remote = channel.remoteAddress(); this.logId = InternalLogId.allocate(getClass(), remote != null ? 
remote.toString() : null); } @@ -275,6 +282,7 @@ private NettyServerHandler createHandler( autoFlowControl, flowControlWindow, maxHeaderListSize, + softLimitHeaderListSize, maxMessageSize, keepAliveTimeInNanos, keepAliveTimeoutInNanos, @@ -285,6 +293,7 @@ private NettyServerHandler createHandler( permitKeepAliveTimeInNanos, maxRstCount, maxRstPeriodNanos, - eagAttributes); + eagAttributes, + metricRecorder); } } diff --git a/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java b/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java index ede511b68f6..3d3fdc67e8e 100644 --- a/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java +++ b/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java @@ -34,6 +34,6 @@ public static ChannelCredentials create(SslContext sslContext) { Preconditions.checkArgument(sslContext.isClient(), "Server SSL context can not be used for client channel"); GrpcSslContexts.ensureAlpnAndH2Enabled(sslContext.applicationProtocolNegotiator()); - return NettyChannelCredentials.create(ProtocolNegotiators.tlsClientFactory(sslContext)); + return NettyChannelCredentials.create(ProtocolNegotiators.tlsClientFactory(sslContext, null)); } } diff --git a/netty/src/main/java/io/grpc/netty/NettyWritableBufferAllocator.java b/netty/src/main/java/io/grpc/netty/NettyWritableBufferAllocator.java index 9e93ee1155c..40b84717160 100644 --- a/netty/src/main/java/io/grpc/netty/NettyWritableBufferAllocator.java +++ b/netty/src/main/java/io/grpc/netty/NettyWritableBufferAllocator.java @@ -33,9 +33,6 @@ */ class NettyWritableBufferAllocator implements WritableBufferAllocator { - // Use 4k as our minimum buffer size. - private static final int MIN_BUFFER = 4 * 1024; - // Set the maximum buffer size to 1MB. 
private static final int MAX_BUFFER = 1024 * 1024; @@ -47,7 +44,7 @@ class NettyWritableBufferAllocator implements WritableBufferAllocator { @Override public WritableBuffer allocate(int capacityHint) { - capacityHint = Math.min(MAX_BUFFER, Math.max(MIN_BUFFER, capacityHint)); + capacityHint = Math.min(MAX_BUFFER, capacityHint); return new NettyWritableBuffer(allocator.buffer(capacityHint, capacityHint)); } } diff --git a/netty/src/main/java/io/grpc/netty/NoopSslEngine.java b/netty/src/main/java/io/grpc/netty/NoopSslEngine.java new file mode 100644 index 00000000000..7e14dbf0e79 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/NoopSslEngine.java @@ -0,0 +1,151 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import java.nio.ByteBuffer; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; + +/** + * A no-op implementation of SslEngine, to facilitate overriding only the required methods in + * specific implementations. 
+ */ +class NoopSslEngine extends SSLEngine { + @Override + public SSLEngineResult wrap(ByteBuffer[] srcs, int offset, int length, ByteBuffer dst) + throws SSLException { + return null; + } + + @Override + public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts, int offset, int length) + throws SSLException { + return null; + } + + @Override + public Runnable getDelegatedTask() { + return null; + } + + @Override + public void closeInbound() throws SSLException { + + } + + @Override + public boolean isInboundDone() { + return false; + } + + @Override + public void closeOutbound() { + + } + + @Override + public boolean isOutboundDone() { + return false; + } + + @Override + public String[] getSupportedCipherSuites() { + return new String[0]; + } + + @Override + public String[] getEnabledCipherSuites() { + return new String[0]; + } + + @Override + public void setEnabledCipherSuites(String[] suites) { + + } + + @Override + public String[] getSupportedProtocols() { + return new String[0]; + } + + @Override + public String[] getEnabledProtocols() { + return new String[0]; + } + + @Override + public void setEnabledProtocols(String[] protocols) { + + } + + @Override + public SSLSession getSession() { + return null; + } + + @Override + public void beginHandshake() throws SSLException { + + } + + @Override + public SSLEngineResult.HandshakeStatus getHandshakeStatus() { + return null; + } + + @Override + public void setUseClientMode(boolean mode) { + + } + + @Override + public boolean getUseClientMode() { + return false; + } + + @Override + public void setNeedClientAuth(boolean need) { + + } + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean want) { + + } + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean flag) { + + } + + @Override + public boolean getEnableSessionCreation() { + return false; + } +} diff --git 
a/netty/src/main/java/io/grpc/netty/ProtocolNegotiationEvent.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiationEvent.java index 16da79e1af8..8103a2dc79f 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiationEvent.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiationEvent.java @@ -20,10 +20,10 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Objects; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.Internal; import io.grpc.InternalChannelz.Security; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; /** diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java index 8a2c6f104b2..4332fdf2919 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java @@ -63,4 +63,5 @@ interface ServerFactory { */ ProtocolNegotiator newNegotiator(ObjectPool offloadExecutorPool); } + } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 80df5d0e3c7..8faf3d0fae8 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -16,10 +16,10 @@ package io.grpc.netty; -import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Optional; import com.google.common.base.Preconditions; import com.google.errorprone.annotations.ForOverride; import io.grpc.Attributes; @@ -41,18 +41,22 @@ import io.grpc.Status; import io.grpc.TlsChannelCredentials; import io.grpc.TlsServerCredentials; +import io.grpc.internal.CertificateUtils; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; 
+import io.grpc.internal.NoopSslSession; import io.grpc.internal.ObjectPool; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpClientUpgradeHandler; import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http2.Http2ClientUpgradeCodec; @@ -70,8 +74,12 @@ import java.net.SocketAddress; import java.net.URI; import java.nio.channels.ClosedChannelException; +import java.security.GeneralSecurityException; +import java.security.KeyStore; import java.util.Arrays; import java.util.EnumSet; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; import java.util.logging.Level; @@ -81,6 +89,10 @@ import javax.net.ssl.SSLException; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** * Common {@link ProtocolNegotiator}s used by gRPC. 
@@ -94,7 +106,6 @@ final class ProtocolNegotiators { EnumSet.of( TlsServerCredentials.Feature.MTLS, TlsServerCredentials.Feature.CUSTOM_MANAGERS); - private ProtocolNegotiators() { } @@ -116,14 +127,25 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) { new ByteArrayInputStream(tlsCreds.getPrivateKey()), tlsCreds.getPrivateKeyPassword()); } - if (tlsCreds.getTrustManagers() != null) { - builder.trustManager(new FixedTrustManagerFactory(tlsCreds.getTrustManagers())); - } else if (tlsCreds.getRootCertificates() != null) { - builder.trustManager(new ByteArrayInputStream(tlsCreds.getRootCertificates())); - } // else use system default try { - return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build())); - } catch (SSLException ex) { + List trustManagers; + if (tlsCreds.getTrustManagers() != null) { + trustManagers = tlsCreds.getTrustManagers(); + } else if (tlsCreds.getRootCertificates() != null) { + trustManagers = Arrays.asList(CertificateUtils.createTrustManager( + new ByteArrayInputStream(tlsCreds.getRootCertificates()))); + } else { // else use system default + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init((KeyStore) null); + trustManagers = Arrays.asList(tmf.getTrustManagers()); + } + builder.trustManager(new FixedTrustManagerFactory(trustManagers)); + TrustManager x509ExtendedTrustManager = + CertificateUtils.getX509ExtendedTrustManager(trustManagers); + return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build(), + (X509TrustManager) x509ExtendedTrustManager)); + } catch (SSLException | GeneralSecurityException ex) { log.log(Level.FINE, "Exception building SslContext", ex); return FromChannelCredentialsResult.error( "Unable to create SslContext: " + ex.getMessage()); @@ -409,8 +431,8 @@ static final class ServerTlsHandler extends ChannelInboundHandlerAdapter { ServerTlsHandler(ChannelHandler next, SslContext sslContext, final 
ObjectPool executorPool) { - this.sslContext = checkNotNull(sslContext, "sslContext"); - this.next = checkNotNull(next, "next"); + this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); + this.next = Preconditions.checkNotNull(next, "next"); if (executorPool != null) { this.executor = executorPool.getObject(); } @@ -465,18 +487,20 @@ private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession * Returns a {@link ProtocolNegotiator} that does HTTP CONNECT proxy negotiation. */ public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress, - final @Nullable String proxyUsername, final @Nullable String proxyPassword, + final @Nullable Map headers, final @Nullable String proxyUsername, + final @Nullable String proxyPassword, final ProtocolNegotiator negotiator) { - checkNotNull(negotiator, "negotiator"); - checkNotNull(proxyAddress, "proxyAddress"); + Preconditions.checkNotNull(negotiator, "negotiator"); + Preconditions.checkNotNull(proxyAddress, "proxyAddress"); final AsciiString scheme = negotiator.scheme(); class ProxyNegotiator implements ProtocolNegotiator { @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler http2Handler) { ChannelHandler protocolNegotiationHandler = negotiator.newHandler(http2Handler); ChannelLogger negotiationLogger = http2Handler.getNegotiationLogger(); + HttpHeaders httpHeaders = toHttpHeaders(headers); return new ProxyProtocolNegotiationHandler( - proxyAddress, proxyUsername, proxyPassword, protocolNegotiationHandler, + proxyAddress, httpHeaders, proxyUsername, proxyPassword, protocolNegotiationHandler, negotiationLogger); } @@ -496,6 +520,22 @@ public void close() { return new ProxyNegotiator(); } + /** + * Converts generic Map of headers to Netty's HttpHeaders. + * Returns null if the map is null or empty. 
+ */ + @Nullable + private static HttpHeaders toHttpHeaders(@Nullable Map headers) { + if (headers == null || headers.isEmpty()) { + return null; + } + HttpHeaders httpHeaders = new DefaultHttpHeaders(); + for (Map.Entry entry : headers.entrySet()) { + httpHeaders.add(entry.getKey(), entry.getValue()); + } + return httpHeaders; + } + /** * A Proxy handler follows {@link ProtocolNegotiationHandler} pattern. Upon successful proxy * connection, this handler will install {@code next} handler which should be a handler from @@ -504,17 +544,20 @@ public void close() { static final class ProxyProtocolNegotiationHandler extends ProtocolNegotiationHandler { private final SocketAddress address; + @Nullable private final HttpHeaders httpHeaders; @Nullable private final String userName; @Nullable private final String password; public ProxyProtocolNegotiationHandler( SocketAddress address, + @Nullable HttpHeaders httpHeaders, @Nullable String userName, @Nullable String password, ChannelHandler next, ChannelLogger negotiationLogger) { super(next, negotiationLogger); - this.address = checkNotNull(address, "address"); + this.address = Preconditions.checkNotNull(address, "address"); + this.httpHeaders = httpHeaders; this.userName = userName; this.password = password; } @@ -523,9 +566,9 @@ public ProxyProtocolNegotiationHandler( protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { HttpProxyHandler nettyProxyHandler; if (userName == null || password == null) { - nettyProxyHandler = new HttpProxyHandler(address); + nettyProxyHandler = new HttpProxyHandler(address, httpHeaders); } else { - nettyProxyHandler = new HttpProxyHandler(address, userName, password); + nettyProxyHandler = new HttpProxyHandler(address, userName, password, httpHeaders); } ctx.pipeline().addBefore(ctx.name(), /* name= */ null, nettyProxyHandler); } @@ -543,16 +586,23 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws static final class 
ClientTlsProtocolNegotiator implements ProtocolNegotiator { public ClientTlsProtocolNegotiator(SslContext sslContext, - ObjectPool executorPool) { - this.sslContext = checkNotNull(sslContext, "sslContext"); + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, String sni) { + this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); this.executorPool = executorPool; if (this.executorPool != null) { this.executor = this.executorPool.getObject(); } + this.handshakeCompleteRunnable = handshakeCompleteRunnable; + this.x509ExtendedTrustManager = x509ExtendedTrustManager; + this.sni = sni; } private final SslContext sslContext; private final ObjectPool executorPool; + private final Optional handshakeCompleteRunnable; + private final X509TrustManager x509ExtendedTrustManager; + private final String sni; private Executor executor; @Override @@ -564,8 +614,17 @@ public AsciiString scheme() { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler); ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger(); - ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(), - this.executor, negotiationLogger); + String authority; + if ("".equals(sni)) { + authority = null; + } else if (sni != null) { + authority = sni; + } else { + authority = grpcHandler.getAuthority(); + } + ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, + authority, this.executor, negotiationLogger, handshakeCompleteRunnable, this, + x509ExtendedTrustManager); return new WaitUntilActiveHandler(cth, negotiationLogger); } @@ -575,6 +634,11 @@ public void close() { this.executorPool.returnObject(this.executor); } } + + @VisibleForTesting + boolean hasX509ExtendedTrustManager() { + return x509ExtendedTrustManager != null; + } } static final class ClientTlsHandler extends ProtocolNegotiationHandler { @@ -583,20 +647,38 @@ 
static final class ClientTlsHandler extends ProtocolNegotiationHandler { private final String host; private final int port; private Executor executor; + private final Optional handshakeCompleteRunnable; + private final X509TrustManager x509TrustManager; + private SSLEngine sslEngine; ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority, - Executor executor, ChannelLogger negotiationLogger) { + Executor executor, ChannelLogger negotiationLogger, + Optional handshakeCompleteRunnable, + ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, + X509TrustManager x509TrustManager) { super(next, negotiationLogger); - this.sslContext = checkNotNull(sslContext, "sslContext"); - HostPort hostPort = parseAuthority(authority); - this.host = hostPort.host; - this.port = hostPort.port; + this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); + if (authority != null) { + HostPort hostPort = parseAuthority(authority); + this.host = hostPort.host; + this.port = hostPort.port; + } else { + this.host = null; + this.port = 0; + } this.executor = executor; + this.handshakeCompleteRunnable = handshakeCompleteRunnable; + this.x509TrustManager = x509TrustManager; } @Override + @IgnoreJRERequirement protected void handlerAdded0(ChannelHandlerContext ctx) { - SSLEngine sslEngine = sslContext.newEngine(ctx.alloc(), host, port); + if (host != null) { + sslEngine = sslContext.newEngine(ctx.alloc(), host, port); + } else { + sslEngine = sslContext.newEngine(ctx.alloc()); + } SSLParameters sslParams = sslEngine.getSSLParameters(); sslParams.setEndpointIdentificationAlgorithm("HTTPS"); sslEngine.setSSLParameters(sslParams); @@ -636,6 +718,9 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws } ctx.fireExceptionCaught(t); } + if (handshakeCompleteRunnable.isPresent()) { + handshakeCompleteRunnable.get().run(); + } } else { super.userEventTriggered0(ctx, evt); } @@ -647,8 +732,13 @@ private void 
propagateTlsComplete(ChannelHandlerContext ctx, SSLSession session) Attributes attrs = existingPne.getAttributes().toBuilder() .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY) .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session) + .set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, new X509AuthorityVerifier( + sslEngine, x509TrustManager)) .build(); replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security)); + if (handshakeCompleteRunnable.isPresent()) { + handshakeCompleteRunnable.get().run(); + } fireProtocolNegotiationEvent(ctx); } } @@ -680,11 +770,14 @@ static HostPort parseAuthority(String authority) { * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. + * * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool) { - return new ClientTlsProtocolNegotiator(sslContext, executorPool); + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, String sni) { + return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable, + x509ExtendedTrustManager, sni); } /** @@ -692,25 +785,30 @@ public static ProtocolNegotiator tls(SslContext sslContext, * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. 
*/ - public static ProtocolNegotiator tls(SslContext sslContext) { - return tls(sslContext, null); + public static ProtocolNegotiator tls(SslContext sslContext, + X509TrustManager x509ExtendedTrustManager) { + return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager, null); } - public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) { - return new TlsProtocolNegotiatorClientFactory(sslContext); + public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext, + X509TrustManager x509ExtendedTrustManager) { + return new TlsProtocolNegotiatorClientFactory(sslContext, x509ExtendedTrustManager); } @VisibleForTesting static final class TlsProtocolNegotiatorClientFactory implements ProtocolNegotiator.ClientFactory { private final SslContext sslContext; + private final X509TrustManager x509ExtendedTrustManager; - public TlsProtocolNegotiatorClientFactory(SslContext sslContext) { + public TlsProtocolNegotiatorClientFactory(SslContext sslContext, + X509TrustManager x509ExtendedTrustManager) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); + this.x509ExtendedTrustManager = x509ExtendedTrustManager; } @Override public ProtocolNegotiator newNegotiator() { - return tls(sslContext); + return tls(sslContext, x509ExtendedTrustManager); } @Override public int getDefaultPort() { @@ -763,7 +861,9 @@ public AsciiString scheme() { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler upgradeHandler = new Http2UpgradeAndGrpcHandler(grpcHandler.getAuthority(), grpcHandler); - return new WaitUntilActiveHandler(upgradeHandler, grpcHandler.getNegotiationLogger()); + ChannelHandler plaintextHandler = + new PlaintextHandler(upgradeHandler, grpcHandler.getNegotiationLogger()); + return new WaitUntilActiveHandler(plaintextHandler, grpcHandler.getNegotiationLogger()); } @Override @@ -784,8 +884,8 @@ static final class Http2UpgradeAndGrpcHandler extends 
ChannelInboundHandlerAdapt private ProtocolNegotiationEvent pne; Http2UpgradeAndGrpcHandler(String authority, GrpcHttp2ConnectionHandler next) { - this.authority = checkNotNull(authority, "authority"); - this.next = checkNotNull(next, "next"); + this.authority = Preconditions.checkNotNull(authority, "authority"); + this.next = Preconditions.checkNotNull(next, "next"); this.negotiationLogger = next.getNegotiationLogger(); } @@ -829,9 +929,9 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } /** - * Returns a {@link ChannelHandler} that ensures that the {@code handler} is added to the - * pipeline writes to the {@link io.netty.channel.Channel} may happen immediately, even before it - * is active. + * Returns a {@link io.netty.channel.ChannelHandler} that ensures that the {@code handler} is + * added to the pipeline writes to the {@link io.netty.channel.Channel} may happen immediately, + * even before it is active. */ public static ProtocolNegotiator plaintext() { return new PlaintextProtocolNegotiator(); @@ -909,7 +1009,7 @@ static final class GrpcNegotiationHandler extends ChannelInboundHandlerAdapter { private final GrpcHttp2ConnectionHandler next; public GrpcNegotiationHandler(GrpcHttp2ConnectionHandler next) { - this.next = checkNotNull(next, "next"); + this.next = Preconditions.checkNotNull(next, "next"); } @Override @@ -960,7 +1060,9 @@ static final class PlaintextProtocolNegotiator implements ProtocolNegotiator { @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler grpcNegotiationHandler = new GrpcNegotiationHandler(grpcHandler); - ChannelHandler activeHandler = new WaitUntilActiveHandler(grpcNegotiationHandler, + ChannelHandler plaintextHandler = + new PlaintextHandler(grpcNegotiationHandler, grpcHandler.getNegotiationLogger()); + ChannelHandler activeHandler = new WaitUntilActiveHandler(plaintextHandler, grpcHandler.getNegotiationLogger()); return activeHandler; } @@ -974,6 
+1076,22 @@ public AsciiString scheme() { } } + static final class PlaintextHandler extends ProtocolNegotiationHandler { + PlaintextHandler(ChannelHandler next, ChannelLogger negotiationLogger) { + super(next, negotiationLogger); + } + + @Override + protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { + ProtocolNegotiationEvent existingPne = getProtocolNegotiationEvent(); + Attributes attrs = existingPne.getAttributes().toBuilder() + .set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, (authority) -> Status.OK) + .build(); + replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs)); + fireProtocolNegotiationEvent(ctx); + } + } + /** * Waits for the channel to be active, and then installs the next Handler. Using this allows * subsequent handlers to assume the channel is active and ready to send. Additionally, this a @@ -1031,15 +1149,15 @@ static class ProtocolNegotiationHandler extends ChannelDuplexHandler { protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName, ChannelLogger negotiationLogger) { - this.next = checkNotNull(next, "next"); + this.next = Preconditions.checkNotNull(next, "next"); this.negotiatorName = negotiatorName; - this.negotiationLogger = checkNotNull(negotiationLogger, "negotiationLogger"); + this.negotiationLogger = Preconditions.checkNotNull(negotiationLogger, "negotiationLogger"); } protected ProtocolNegotiationHandler(ChannelHandler next, ChannelLogger negotiationLogger) { - this.next = checkNotNull(next, "next"); + this.next = Preconditions.checkNotNull(next, "next"); this.negotiatorName = getClass().getSimpleName().replace("Handler", ""); - this.negotiationLogger = checkNotNull(negotiationLogger, "negotiationLogger"); + this.negotiationLogger = Preconditions.checkNotNull(negotiationLogger, "negotiationLogger"); } @Override @@ -1080,7 +1198,7 @@ protected final ProtocolNegotiationEvent getProtocolNegotiationEvent() { protected final void 
replaceProtocolNegotiationEvent(ProtocolNegotiationEvent pne) { checkState(this.pne != null, "previous protocol negotiation event hasn't triggered"); - this.pne = checkNotNull(pne); + this.pne = Preconditions.checkNotNull(pne); } protected final void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) { @@ -1090,4 +1208,42 @@ protected final void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) { ctx.fireUserEventTriggered(pne); } } + + static final class SslEngineWrapper extends NoopSslEngine { + private final SSLEngine sslEngine; + private final String peerHost; + + SslEngineWrapper(SSLEngine sslEngine, String peerHost) { + this.sslEngine = sslEngine; + this.peerHost = peerHost; + } + + @Override + public String getPeerHost() { + return peerHost; + } + + @Override + public SSLSession getHandshakeSession() { + return new FakeSslSession(peerHost); + } + + @Override + public SSLParameters getSSLParameters() { + return sslEngine.getSSLParameters(); + } + } + + static final class FakeSslSession extends NoopSslSession { + private final String peerHost; + + FakeSslSession(String peerHost) { + this.peerHost = peerHost; + } + + @Override + public String getPeerHost() { + return peerHost; + } + } } diff --git a/netty/src/main/java/io/grpc/netty/TcpMetrics.java b/netty/src/main/java/io/grpc/netty/TcpMetrics.java new file mode 100644 index 00000000000..c5809a5677e --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/TcpMetrics.java @@ -0,0 +1,227 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.InternalTcpMetrics; +import io.grpc.MetricRecorder; +import io.netty.channel.Channel; +import io.netty.util.concurrent.ScheduledFuture; +import java.lang.reflect.Method; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Utility for collecting TCP metrics from Netty channels. + */ +final class TcpMetrics { + private static final Logger log = Logger.getLogger(TcpMetrics.class.getName()); + + static EpollInfo epollInfo = loadEpollInfo(); + + static final class EpollInfo { + final Class channelClass; + final java.lang.reflect.Constructor infoConstructor; + final Method tcpInfo; + final Method totalRetrans; + final Method retransmits; + final Method rtt; + + EpollInfo( + Class channelClass, + java.lang.reflect.Constructor infoConstructor, + Method tcpInfo, + Method totalRetrans, + Method retransmits, + Method rtt) { + this.channelClass = channelClass; + this.infoConstructor = infoConstructor; + this.tcpInfo = tcpInfo; + this.totalRetrans = totalRetrans; + this.retransmits = retransmits; + this.rtt = rtt; + } + } + + static EpollInfo loadEpollInfo() { + boolean epollAvailable = false; + try { + Class epollClass = Class.forName("io.netty.channel.epoll.Epoll"); + Method isAvailableMethod = epollClass.getDeclaredMethod("isAvailable"); + epollAvailable = (Boolean) isAvailableMethod.invoke(null); + if (epollAvailable) { + Class channelClass = Class.forName("io.netty.channel.epoll.EpollSocketChannel"); + Class infoClass = Class.forName("io.netty.channel.epoll.EpollTcpInfo"); + return new EpollInfo( + 
channelClass, + infoClass.getDeclaredConstructor(), + channelClass.getMethod("tcpInfo", infoClass), + infoClass.getMethod("totalRetrans"), + infoClass.getMethod("retrans"), + infoClass.getMethod("rtt")); + } + } catch (ReflectiveOperationException e) { + log.log(Level.FINE, "Failed to initialize Epoll tcp_info reflection", e); + } finally { + log.log(Level.INFO, "Epoll available during static init of TcpMetrics:" + + "{0}", epollAvailable); + } + return null; + } + + private static final long RECORD_INTERVAL_MILLIS = TimeUnit.MINUTES.toMillis(5); + private final MetricRecorder metricRecorder; + private final Object tcpInfo; + private long lastTotalRetrans = 0; + private ScheduledFuture reportTimer; + + TcpMetrics(MetricRecorder metricRecorder) { + this.metricRecorder = metricRecorder; + + Object tcpInfo = null; + if (epollInfo != null) { + try { + tcpInfo = epollInfo.infoConstructor.newInstance(); + } catch (ReflectiveOperationException e) { + log.log(Level.FINE, "Failed to instantiate EpollTcpInfo", e); + } + } + this.tcpInfo = tcpInfo; + } + + void channelActive(Channel channel) { + List labelValues = getLabelValues(channel); + metricRecorder.addLongCounter(InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT, 1, + Collections.emptyList(), labelValues); + metricRecorder.addLongUpDownCounter(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT, 1, + Collections.emptyList(), labelValues); + scheduleNextReport(channel, true); + } + + private void scheduleNextReport(final Channel channel, boolean isInitial) { + if (epollInfo == null || !epollInfo.channelClass.isInstance(channel) || !channel.isActive()) { + return; + } + + // Initial report has a larger jitter range to spread out initial connections. + // Subsequent reports have a smaller jitter range to avoid drift. + double jitter = isInitial + ? 
0.1 + ThreadLocalRandom.current().nextDouble() // 10% to 110% + : 0.9 + ThreadLocalRandom.current().nextDouble() * 0.2; // 90% to 110% + long rearmingDelay = (long) (RECORD_INTERVAL_MILLIS * jitter); + + reportTimer = channel.eventLoop().schedule(() -> { + if (channel.isActive()) { + recordTcpInfo(channel, false); + scheduleNextReport(channel, false); // Re-arm + } + }, rearmingDelay, TimeUnit.MILLISECONDS); + } + + void channelInactive(Channel channel) { + if (reportTimer != null) { + reportTimer.cancel(false); + } + List labelValues = getLabelValues(channel); + metricRecorder.addLongUpDownCounter(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT, -1, + Collections.emptyList(), labelValues); + // Final collection on close + if (epollInfo != null && epollInfo.channelClass.isInstance(channel)) { + recordTcpInfo(channel, true); + } + } + + void recordTcpInfo(Channel channel) { + recordTcpInfo(channel, false); + } + + private void recordTcpInfo(Channel channel, boolean isClose) { + if (epollInfo == null || !epollInfo.channelClass.isInstance(channel)) { + return; + } + List labelValues = getLabelValues(channel); + long totalRetrans; + long retransmits; + long rtt; + try { + epollInfo.tcpInfo.invoke(channel, tcpInfo); + totalRetrans = (Long) epollInfo.totalRetrans.invoke(tcpInfo); + retransmits = (Long) epollInfo.retransmits.invoke(tcpInfo); + rtt = (Long) epollInfo.rtt.invoke(tcpInfo); + } catch (ReflectiveOperationException e) { + log.log(Level.FINE, "Error computing TCP metrics", e); + return; + } + + long deltaTotal = totalRetrans - lastTotalRetrans; + if (deltaTotal > 0) { + metricRecorder.addLongCounter(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT, + deltaTotal, Collections.emptyList(), labelValues); + lastTotalRetrans = totalRetrans; + } + if (isClose && retransmits > 0) { + metricRecorder.addLongCounter(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT, + retransmits, Collections.emptyList(), labelValues); + } + 
metricRecorder.recordDoubleHistogram(InternalTcpMetrics.MIN_RTT_INSTRUMENT, + rtt / 1000000.0, // Convert microseconds to seconds + Collections.emptyList(), labelValues); + } + + @VisibleForTesting + ScheduledFuture getReportTimer() { + return reportTimer; + } + + private static List getLabelValues(Channel channel) { + String localAddress = ""; + String localPort = ""; + String peerAddress = ""; + String peerPort = ""; + + SocketAddress local = channel.localAddress(); + if (local instanceof InetSocketAddress) { + InetSocketAddress inetLocal = (InetSocketAddress) local; + if (inetLocal.getAddress() != null) { + localAddress = inetLocal.getAddress().getHostAddress(); + } else if (inetLocal.getHostString() != null) { + localAddress = inetLocal.getHostString(); + } + localPort = String.valueOf(inetLocal.getPort()); + } + + SocketAddress remote = channel.remoteAddress(); + if (remote instanceof InetSocketAddress) { + InetSocketAddress inetRemote = (InetSocketAddress) remote; + if (inetRemote.getAddress() != null) { + peerAddress = inetRemote.getAddress().getHostAddress(); + } else if (inetRemote.getHostString() != null) { + peerAddress = inetRemote.getHostString(); + } + peerPort = String.valueOf(inetRemote.getPort()); + } + + return Arrays.asList(localAddress, localPort, peerAddress, peerPort); + } +} diff --git a/netty/src/main/java/io/grpc/netty/UdsNameResolver.java b/netty/src/main/java/io/grpc/netty/UdsNameResolver.java index 8fa8ea06250..3477a458933 100644 --- a/netty/src/main/java/io/grpc/netty/UdsNameResolver.java +++ b/netty/src/main/java/io/grpc/netty/UdsNameResolver.java @@ -18,21 +18,32 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Strings.isNullOrEmpty; import com.google.common.base.Preconditions; import io.grpc.EquivalentAddressGroup; import io.grpc.NameResolver; +import io.grpc.StatusOr; import 
io.netty.channel.unix.DomainSocketAddress; import java.util.ArrayList; -import java.util.Collections; import java.util.List; final class UdsNameResolver extends NameResolver { private NameResolver.Listener2 listener; private final String authority; - UdsNameResolver(String authority, String targetPath) { - checkArgument(authority == null, "non-null authority not supported"); + /** + * Constructs a new instance of UdsNameResolver. + * + * @param authority authority of the 'unix:' URI to resolve, or null if target has no authority + * @param targetPath path of the 'unix:' URI to resolve + */ + UdsNameResolver(String authority, String targetPath, Args args) { + // UDS is inherently local. According to https://github.com/grpc/grpc/blob/master/doc/naming.md, + // this is expressed in the target URI either by using a blank authority, like "unix:///sock", + // or by omitting authority completely, e.g. "unix:/sock". + // TODO(jdcormie): Allow the explicit authority string "localhost"? + checkArgument(isNullOrEmpty(authority), "authority not supported: %s", authority); this.authority = targetPath; } @@ -57,8 +68,8 @@ private void resolve() { ResolutionResult.Builder resolutionResultBuilder = ResolutionResult.newBuilder(); List servers = new ArrayList<>(1); servers.add(new EquivalentAddressGroup(new DomainSocketAddress(authority))); - resolutionResultBuilder.setAddresses(Collections.unmodifiableList(servers)); - listener.onResult(resolutionResultBuilder.build()); + resolutionResultBuilder.setAddressesOrError(StatusOr.fromValue(servers)); + listener.onResult2(resolutionResultBuilder.build()); } @Override diff --git a/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java b/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java index 9f594193b4c..baf18e3d7de 100644 --- a/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java +++ b/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java @@ -20,6 +20,7 @@ import io.grpc.Internal; import 
io.grpc.NameResolver; import io.grpc.NameResolverProvider; +import io.grpc.Uri; import io.netty.channel.unix.DomainSocketAddress; import java.net.SocketAddress; import java.net.URI; @@ -31,10 +32,22 @@ public final class UdsNameResolverProvider extends NameResolverProvider { private static final String SCHEME = "unix"; + @Override + public NameResolver newNameResolver(Uri targetUri, NameResolver.Args args) { + if (SCHEME.equals(targetUri.getScheme())) { + return new UdsNameResolver(targetUri.getAuthority(), targetUri.getPath(), args); + } else { + return null; + } + } + @Override public UdsNameResolver newNameResolver(URI targetUri, NameResolver.Args args) { if (SCHEME.equals(targetUri.getScheme())) { - return new UdsNameResolver(targetUri.getAuthority(), getTargetPathFromUri(targetUri)); + // TODO(jdcormie): java.net.URI has a bug where getAuthority() returns null for both the + // undefined and zero-length authority. Doesn't matter for now because UdsNameResolver doesn't + // distinguish these cases. + return new UdsNameResolver(targetUri.getAuthority(), getTargetPathFromUri(targetUri), args); } else { return null; } @@ -44,6 +57,10 @@ static String getTargetPathFromUri(URI targetUri) { Preconditions.checkArgument(SCHEME.equals(targetUri.getScheme()), "scheme must be " + SCHEME); String targetPath = targetUri.getPath(); if (targetPath == null) { + // TODO(jdcormie): This incorrectly includes '?' and any characters that follow. In the + // hierarchical case ('unix:///path'), java.net.URI parses these into a query component that's + // distinct from the path. But in the present "opaque" case ('unix:/path'), what may look like + // a query is considered part of the SSP. 
targetPath = Preconditions.checkNotNull(targetUri.getSchemeSpecificPart(), "targetPath"); } return targetPath; diff --git a/netty/src/main/java/io/grpc/netty/Utils.java b/netty/src/main/java/io/grpc/netty/Utils.java index 96f19aab5e3..386df20ba0b 100644 --- a/netty/src/main/java/io/grpc/netty/Utils.java +++ b/netty/src/main/java/io/grpc/netty/Utils.java @@ -23,9 +23,11 @@ import static io.netty.channel.ChannelOption.SO_LINGER; import static io.netty.channel.ChannelOption.SO_TIMEOUT; import static io.netty.util.CharsetUtil.UTF_8; +import static java.nio.charset.StandardCharsets.US_ASCII; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.InternalChannelz; import io.grpc.InternalMetadata; import io.grpc.Metadata; @@ -67,7 +69,6 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.net.ssl.SSLException; @@ -91,7 +92,9 @@ class Utils { = new DefaultEventLoopGroupResource(1, "grpc-nio-boss-ELG", EventLoopGroupType.NIO); public static final Resource NIO_WORKER_EVENT_LOOP_GROUP = new DefaultEventLoopGroupResource(0, "grpc-nio-worker-ELG", EventLoopGroupType.NIO); - + private static final int HEADER_ENTRY_OVERHEAD = 32; + private static final byte[] binaryHeaderSuffixBytes = + Metadata.BINARY_HEADER_SUFFIX.getBytes(US_ASCII); public static final Resource DEFAULT_BOSS_EVENT_LOOP_GROUP; public static final Resource DEFAULT_WORKER_EVENT_LOOP_GROUP; @@ -119,10 +122,10 @@ private static final class ByteBufAllocatorPreferHeapHolder { EPOLL_DOMAIN_CLIENT_CHANNEL_TYPE = epollDomainSocketChannelType(); DEFAULT_SERVER_CHANNEL_FACTORY = new ReflectiveChannelFactory<>(epollServerChannelType()); EPOLL_EVENT_LOOP_GROUP_CONSTRUCTOR = epollEventLoopGroupConstructor(); - DEFAULT_BOSS_EVENT_LOOP_GROUP - = new 
DefaultEventLoopGroupResource(1, "grpc-default-boss-ELG", EventLoopGroupType.EPOLL); - DEFAULT_WORKER_EVENT_LOOP_GROUP - = new DefaultEventLoopGroupResource(0,"grpc-default-worker-ELG", EventLoopGroupType.EPOLL); + DEFAULT_BOSS_EVENT_LOOP_GROUP = new DefaultEventLoopGroupResource( + 1, "grpc-default-boss-ELG", EventLoopGroupType.EPOLL); + DEFAULT_WORKER_EVENT_LOOP_GROUP = new DefaultEventLoopGroupResource( + 0, "grpc-default-worker-ELG", EventLoopGroupType.EPOLL); } else { logger.log(Level.FINE, "Epoll is not available, using Nio.", getEpollUnavailabilityCause()); DEFAULT_SERVER_CHANNEL_FACTORY = nioServerChannelFactory(); @@ -195,6 +198,61 @@ public static Metadata convertHeaders(Http2Headers http2Headers) { return InternalMetadata.newMetadata(convertHeadersToArray(http2Headers)); } + public static int getH2HeadersSize(Http2Headers http2Headers) { + if (http2Headers instanceof GrpcHttp2InboundHeaders) { + GrpcHttp2InboundHeaders h = (GrpcHttp2InboundHeaders) http2Headers; + int size = 0; + for (int i = 0; i < h.numHeaders(); i++) { + size += h.namesAndValues()[2 * i].length; + size += + maybeAddBinaryHeaderOverhead(h.namesAndValues()[2 * i], h.namesAndValues()[2 * i + 1]); + size += HEADER_ENTRY_OVERHEAD; + } + return size; + } + + // the binary header is not decoded yet, no need to add overhead. 
+ int size = 0; + for (Map.Entry entry : http2Headers) { + size += entry.getKey().length(); + size += entry.getValue().length(); + size += HEADER_ENTRY_OVERHEAD; + } + return size; + } + + private static int maybeAddBinaryHeaderOverhead(byte[] name, byte[] value) { + if (endsWith(name, binaryHeaderSuffixBytes)) { + return value.length * 4 / 3; + } + return value.length; + } + + private static boolean endsWith(byte[] bytes, byte[] suffix) { + if (bytes == null || suffix == null || bytes.length < suffix.length) { + return false; + } + + for (int i = 0; i < suffix.length; i++) { + if (bytes[bytes.length - suffix.length + i] != suffix[i]) { + return false; + } + } + + return true; + } + + public static boolean shouldRejectOnMetadataSizeSoftLimitExceeded( + int h2HeadersSize, int softLimitHeaderListSize, int maxHeaderListSize) { + if (h2HeadersSize < softLimitHeaderListSize) { + return false; + } + double failProbability = + (double) (h2HeadersSize - softLimitHeaderListSize) / (double) (maxHeaderListSize + - softLimitHeaderListSize); + return Math.random() < failProbability; + } + @CheckReturnValue private static byte[][] convertHeadersToArray(Http2Headers http2Headers) { // The Netty AsciiString class is really just a wrapper around a byte[] and supports diff --git a/netty/src/main/java/io/grpc/netty/X509AuthorityVerifier.java b/netty/src/main/java/io/grpc/netty/X509AuthorityVerifier.java new file mode 100644 index 00000000000..a2df3dbc431 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/X509AuthorityVerifier.java @@ -0,0 +1,108 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.grpc.Status; +import io.grpc.internal.AuthorityVerifier; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import javax.annotation.Nonnull; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.X509TrustManager; + +final class X509AuthorityVerifier implements AuthorityVerifier { + private final SSLEngine sslEngine; + private final X509TrustManager x509ExtendedTrustManager; + + private static final Method checkServerTrustedMethod; + + static { + Method method = null; + try { + Class x509ExtendedTrustManagerClass = + Class.forName("javax.net.ssl.X509ExtendedTrustManager"); + method = x509ExtendedTrustManagerClass.getMethod("checkServerTrusted", + X509Certificate[].class, String.class, SSLEngine.class); + } catch (ClassNotFoundException e) { + // Per-rpc authority overriding via call options will be disallowed. + } catch (NoSuchMethodException e) { + // Should never happen since X509ExtendedTrustManager was introduced in Android API level 24 + // along with checkServerTrusted. 
+ } + checkServerTrustedMethod = method; + } + + public X509AuthorityVerifier(SSLEngine sslEngine, X509TrustManager x509ExtendedTrustManager) { + this.sslEngine = checkNotNull(sslEngine, "sslEngine"); + this.x509ExtendedTrustManager = x509ExtendedTrustManager; + } + + @Override + public Status verifyAuthority(@Nonnull String authority) { + if (x509ExtendedTrustManager == null) { + return Status.UNAVAILABLE.withDescription( + "Can't allow authority override in rpc when X509ExtendedTrustManager" + + " is not available"); + } + Status peerVerificationStatus; + try { + // Because the authority pseudo-header can contain a port number: + // https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2.3 + verifyAuthorityAllowedForPeerCert(removeAnyPortNumber(authority)); + peerVerificationStatus = Status.OK; + } catch (SSLPeerUnverifiedException | CertificateException | InvocationTargetException + | IllegalAccessException | IllegalStateException e) { + peerVerificationStatus = Status.UNAVAILABLE.withDescription( + String.format("Peer hostname verification during rpc failed for authority '%s'", + authority)).withCause(e); + } + return peerVerificationStatus; + } + + private String removeAnyPortNumber(String authority) { + int closingSquareBracketIndex = authority.lastIndexOf(']'); + int portNumberSeperatorColonIndex = authority.lastIndexOf(':'); + if (portNumberSeperatorColonIndex > closingSquareBracketIndex) { + return authority.substring(0, portNumberSeperatorColonIndex); + } + return authority; + } + + private void verifyAuthorityAllowedForPeerCert(String authority) + throws SSLPeerUnverifiedException, CertificateException, InvocationTargetException, + IllegalAccessException { + SSLEngine sslEngineWrapper = new ProtocolNegotiators.SslEngineWrapper(sslEngine, authority); + // The typecasting of Certificate to X509Certificate should work because this method will only + // be called when using TLS and thus X509. 
+ Certificate[] peerCertificates = sslEngine.getSession().getPeerCertificates(); + X509Certificate[] x509PeerCertificates = new X509Certificate[peerCertificates.length]; + for (int i = 0; i < peerCertificates.length; i++) { + x509PeerCertificates[i] = (X509Certificate) peerCertificates[i]; + } + if (checkServerTrustedMethod == null) { + throw new IllegalStateException("checkServerTrustedMethod not found"); + } + checkServerTrustedMethod.invoke( + x509ExtendedTrustManager, x509PeerCertificates, "UNKNOWN", sslEngineWrapper); + } +} diff --git a/netty/src/main/java/io/grpc/netty/package-info.java b/netty/src/main/java/io/grpc/netty/package-info.java index 54595b38573..d1d7b87cf51 100644 --- a/netty/src/main/java/io/grpc/netty/package-info.java +++ b/netty/src/main/java/io/grpc/netty/package-info.java @@ -18,5 +18,5 @@ * The main transport implementation based on Netty, * for both the client and the server. */ -@javax.annotation.CheckReturnValue +@com.google.errorprone.annotations.CheckReturnValue package io.grpc.netty; diff --git a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java index f34e336553b..66591cda153 100644 --- a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java +++ b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java @@ -436,16 +436,13 @@ public void onFileReloadingTrustManagerBadInitialContentTest() throws Exception } @Test - public void keyManagerAliasesTest() { + public void keyManagerAliasesTest() throws Exception { AdvancedTlsX509KeyManager km = new AdvancedTlsX509KeyManager(); - assertArrayEquals( - new String[] {"default"}, km.getClientAliases("", null)); - assertEquals( - "default", km.chooseClientAlias(new String[] {"default"}, null, null)); - assertArrayEquals( - new String[] {"default"}, km.getServerAliases("", null)); - assertEquals( - "default", km.chooseServerAlias("default", null, null)); + km.updateIdentityCredentials(serverCert0, serverKey0); + assertArrayEquals(new 
String[] {"key-1"}, km.getClientAliases("", null)); + assertEquals("key-1", km.chooseClientAlias(new String[] {"key-1"}, null, null)); + assertArrayEquals(new String[] {"key-1"}, km.getServerAliases("", null)); + assertEquals("key-1", km.chooseServerAlias("key-1", null, null)); } @Test diff --git a/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java b/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java index 1037dab4712..b19f247b5cf 100644 --- a/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java @@ -40,7 +40,6 @@ import io.netty.buffer.UnpooledByteBufAllocator; import java.util.Collection; import java.util.List; -import java.util.stream.Collectors; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -53,9 +52,12 @@ @RunWith(Enclosed.class) public class NettyAdaptiveCumulatorTest { + private static boolean usingPre4_1_111_Netty() { + return false; // Disabled detection because it was unreliable + } private static Collection cartesianProductParams(List... lists) { - return Lists.cartesianProduct(lists).stream().map(List::toArray).collect(Collectors.toList()); + return Lists.transform(Lists.cartesianProduct(lists), List::toArray); } @RunWith(JUnit4.class) @@ -386,9 +388,8 @@ public void mergeWithCompositeTail_tailExpandable_reallocateInMemory() { } private void assertTailExpanded(String expectedTailReadableData, int expectedNewTailCapacity) { - if (!GrpcHttp2ConnectionHandler.usingPre4_1_111_Netty()) { - return; // Netty 4.1.111 doesn't work with NettyAdaptiveCumulator - } + assume().withMessage("Netty 4.1.111 doesn't work with NettyAdaptiveCumulator") + .that(usingPre4_1_111_Netty()).isTrue(); int originalNumComponents = composite.numComponents(); // Handle the case when reader index is beyond all readable bytes of the cumulation. 
@@ -629,9 +630,8 @@ public void mergeWithCompositeTail_outOfSyncComposite() { alloc.compositeBuffer(8).addFlattenedComponents(true, composite1); assertThat(composite2.toString(US_ASCII)).isEqualTo("01234"); - if (!GrpcHttp2ConnectionHandler.usingPre4_1_111_Netty()) { - return; // Netty 4.1.111 doesn't work with NettyAdaptiveCumulator - } + assume().withMessage("Netty 4.1.111 doesn't work with NettyAdaptiveCumulator") + .that(usingPre4_1_111_Netty()).isTrue(); // The previous operation does not adjust the read indexes of the underlying buffers, // only the internal Component offsets. When the cumulator attempts to append the input to diff --git a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java index 5789d275c07..95d54d13b82 100644 --- a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; @@ -39,17 +40,13 @@ import java.net.SocketAddress; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLException; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class NettyChannelBuilderTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); private final SslContext noSslContext = null; private void shutdown(ManagedChannel mc) throws Exception { @@ -107,10 +104,9 @@ private void overrideAuthorityIsReadableHelper(NettyChannelBuilder builder, public void failOverrideInvalidAuthority() { 
NettyChannelBuilder builder = new NettyChannelBuilder(getTestSocketAddress()); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority:"); - - builder.overrideAuthority("[invalidauthority"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.overrideAuthority("[invalidauthority")); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [invalidauthority"); } @Test @@ -128,20 +124,18 @@ public void enableCheckAuthorityFailOverrideInvalidAuthority() { NettyChannelBuilder builder = new NettyChannelBuilder(getTestSocketAddress()) .disableCheckAuthority() .enableCheckAuthority(); - - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority:"); - builder.overrideAuthority("[invalidauthority"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.overrideAuthority("[invalidauthority")); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [invalidauthority"); } @Test public void failInvalidAuthority() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid host or port"); - @SuppressWarnings("AddressSelection") // We actually expect zero addresses! 
- Object unused = - NettyChannelBuilder.forAddress(new InetSocketAddress("invalid_authority", 1234)); + InetSocketAddress address = new InetSocketAddress("invalid_authority", 1234); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> NettyChannelBuilder.forAddress(address)); + assertThat(e).hasMessageThat().isEqualTo("Invalid host or port: invalid_authority 1234"); } @Test @@ -155,10 +149,10 @@ public void failIfSslContextIsNotClient() { SslContext sslContext = mock(SslContext.class); NettyChannelBuilder builder = new NettyChannelBuilder(getTestSocketAddress()); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Server SSL context can not be used for client channel"); - - builder.sslContext(sslContext); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.sslContext(sslContext)); + assertThat(e).hasMessageThat() + .isEqualTo("Server SSL context can not be used for client channel"); } @Test @@ -166,10 +160,10 @@ public void failNegotiationTypeWithChannelCredentials_target() { NettyChannelBuilder builder = NettyChannelBuilder.forTarget( "fakeTarget", InsecureChannelCredentials.create()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("Cannot change security when using ChannelCredentials"); - - builder.negotiationType(NegotiationType.TLS); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> builder.negotiationType(NegotiationType.TLS)); + assertThat(e).hasMessageThat() + .isEqualTo("Cannot change security when using ChannelCredentials"); } @Test @@ -177,10 +171,10 @@ public void failNegotiationTypeWithChannelCredentials_socketAddress() { NettyChannelBuilder builder = NettyChannelBuilder.forAddress( getTestSocketAddress(), InsecureChannelCredentials.create()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("Cannot change security when using ChannelCredentials"); - - 
builder.negotiationType(NegotiationType.TLS); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> builder.negotiationType(NegotiationType.TLS)); + assertThat(e).hasMessageThat() + .isEqualTo("Cannot change security when using ChannelCredentials"); } @Test @@ -205,10 +199,9 @@ public void createProtocolNegotiatorByType_plaintextUpgrade() { @Test public void createProtocolNegotiatorByType_tlsWithNoContext() { - thrown.expect(NullPointerException.class); - NettyChannelBuilder.createProtocolNegotiatorByType( - NegotiationType.TLS, - noSslContext, null); + assertThrows(NullPointerException.class, + () -> NettyChannelBuilder.createProtocolNegotiatorByType( + NegotiationType.TLS, noSslContext, null)); } @Test @@ -245,38 +238,40 @@ public void createProtocolNegotiatorByType_tlsWithAuthorityFallback() throws SSL public void negativeKeepAliveTime() { NettyChannelBuilder builder = NettyChannelBuilder.forTarget("fakeTarget"); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("keepalive time must be positive"); - builder.keepAliveTime(-1L, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.keepAliveTime(-1L, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("keepalive time must be positive"); } @Test public void negativeKeepAliveTimeout() { NettyChannelBuilder builder = NettyChannelBuilder.forTarget("fakeTarget"); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("keepalive timeout must be positive"); - builder.keepAliveTimeout(-1L, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.keepAliveTimeout(-1L, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("keepalive timeout must be positive"); } @Test public void assertEventLoopAndChannelType_onlyGroupProvided() { NettyChannelBuilder builder = NettyChannelBuilder.forTarget("fakeTarget"); 
builder.eventLoopGroup(mock(EventLoopGroup.class)); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("Both EventLoopGroup and ChannelType should be provided"); - builder.assertEventLoopAndChannelType(); + IllegalStateException e = assertThrows(IllegalStateException.class, + builder::assertEventLoopAndChannelType); + assertThat(e).hasMessageThat() + .isEqualTo("Both EventLoopGroup and ChannelType should be provided or neither should be"); } @Test public void assertEventLoopAndChannelType_onlyTypeProvided() { NettyChannelBuilder builder = NettyChannelBuilder.forTarget("fakeTarget"); builder.channelType(LocalChannel.class, LocalAddress.class); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("Both EventLoopGroup and ChannelType should be provided"); - builder.assertEventLoopAndChannelType(); + IllegalStateException e = assertThrows(IllegalStateException.class, + builder::assertEventLoopAndChannelType); + assertThat(e).hasMessageThat() + .isEqualTo("Both EventLoopGroup and ChannelType should be provided or neither should be"); } @Test @@ -288,10 +283,11 @@ public Channel newChannel() { return null; } }); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("Both EventLoopGroup and ChannelType should be provided"); - builder.assertEventLoopAndChannelType(); + IllegalStateException e = assertThrows(IllegalStateException.class, + builder::assertEventLoopAndChannelType); + assertThat(e).hasMessageThat() + .isEqualTo("Both EventLoopGroup and ChannelType should be provided or neither should be"); } @Test diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index 73988f773cb..9f6be9a2f3e 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -28,7 +28,6 @@ import static io.grpc.netty.Utils.STATUS_OK; import static io.grpc.netty.Utils.TE_HEADER; 
import static io.grpc.netty.Utils.TE_TRAILERS; -import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -36,6 +35,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; @@ -47,6 +47,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import com.google.common.base.Stopwatch; +import com.google.common.base.Strings; import com.google.common.base.Supplier; import com.google.common.base.Ticker; import com.google.common.collect.ImmutableList; @@ -56,16 +57,18 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Metadata; +import io.grpc.MetricRecorder; import io.grpc.Status; -import io.grpc.StatusException; import io.grpc.internal.AbstractStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.ClientTransport; import io.grpc.internal.ClientTransport.PingCallback; +import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; import io.grpc.internal.KeepAliveManager; import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.StreamListener; import io.grpc.internal.TransportTracer; @@ -89,10 +92,12 @@ import io.netty.handler.codec.http2.Http2Stream; import io.netty.util.AsciiString; import java.io.InputStream; +import java.security.cert.CertificateException; import java.text.MessageFormat; import java.util.LinkedList; import java.util.List; import java.util.Queue; +import java.util.concurrent.ExecutionException; 
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Handler; @@ -122,6 +127,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBaseany()); - + doAnswer((attributes) -> Attributes.newBuilder().set( + GrpcAttributes.ATTR_AUTHORITY_VERIFIER, + (authority) -> Status.OK).build()) + .when(listener) + .filterTransport(ArgumentMatchers.any(Attributes.class)); lifecycleManager = new ClientTransportLifecycleManager(listener); // This mocks the keepalive manager only for there's in which we verify it. For other tests // it'll be null which will be testing if we behave correctly when it's not present. @@ -215,6 +225,37 @@ public Void answer(InvocationOnMock invocation) throws Throwable { // Simulate receipt of initial remote settings. ByteBuf serializedSettings = serializeSettings(new Http2Settings()); channelRead(serializedSettings); + channel().releaseOutbound(); + } + + @Test + @SuppressWarnings("InlineMeInliner") + public void sendLargerThanSoftLimitHeaderMayFail() throws Exception { + maxHeaderListSize = 8000; + softLimitHeaderListSize = 2000; + manualSetUp(); + + createStream(); + // total head size of 7999, soft limit = 2000 and max = 8000. + // This header has 5999/6000 chance to be rejected. 
+ Http2Headers headers = new DefaultHttp2Headers() + .scheme(HTTPS) + .authority(as("www.fake.com")) + .path(as("/fakemethod")) + .method(HTTP_METHOD) + .add(as("auth"), as("sometoken")) + .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC) + .add(TE_HEADER, TE_TRAILERS) + .add("large-field", Strings.repeat("a", 7620)); // String.repeat() requires Java 11 + + ByteBuf headersFrame = headersFrame(STREAM_ID, headers); + channelRead(headersFrame); + ArgumentCaptor statusArgumentCaptor = ArgumentCaptor.forClass(Status.class); + verify(streamListener).closed(statusArgumentCaptor.capture(), eq(PROCESSED), + any(Metadata.class)); + assertThat(statusArgumentCaptor.getValue().getCode()).isEqualTo(Status.Code.RESOURCE_EXHAUSTED); + assertThat(statusArgumentCaptor.getValue().getDescription()).contains( + "exceeded Metadata size soft limit"); } @Test @@ -228,7 +269,7 @@ public void cancelBufferedStreamShouldChangeClientStreamStatus() throws Exceptio // Cancel the stream. cancelStream(Status.CANCELLED); - assertTrue(createFuture.isSuccess()); + assertFalse(createFuture.isSuccess()); verify(streamListener).closed(eq(Status.CANCELLED), same(PROCESSED), any(Metadata.class)); } @@ -236,7 +277,7 @@ public void cancelBufferedStreamShouldChangeClientStreamStatus() throws Exceptio public void createStreamShouldSucceed() throws Exception { createStream(); verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(grpcHeaders), eq(0), - eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class)); + eq(false), any(ChannelPromise.class)); } @Test @@ -271,7 +312,7 @@ public void cancelWhileBufferedShouldSucceed() throws Exception { ChannelFuture cancelFuture = cancelStream(Status.CANCELLED); assertTrue(cancelFuture.isSuccess()); assertTrue(createFuture.isDone()); - assertTrue(createFuture.isSuccess()); + assertFalse(createFuture.isSuccess()); } /** @@ -310,11 +351,12 @@ public void sendFrameShouldSucceed() throws Exception { createStream(); // Send a frame and verify that it 
was written. + ByteBuf content = content(); ChannelFuture future - = enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true)); + = enqueue(new SendGrpcFrameCommand(streamTransportState, content, true)); assertTrue(future.isSuccess()); - verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(true), + verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(true), any(ChannelPromise.class)); verify(mockKeepAliveManager, times(1)).onTransportActive(); // onStreamActive verifyNoMoreInteractions(mockKeepAliveManager); @@ -412,6 +454,26 @@ public void receivedAbruptGoAwayShouldFailRacingQueuedStreamid() throws Exceptio assertTrue(future.isDone()); } + @Test + public void receivedAbruptGoAwayShouldFailRacingQueuedIoStreamid() throws Exception { + // Purposefully avoid flush(), since we want the write to not actually complete. + // EmbeddedChannel doesn't support flow control, so this is the next closest approximation. + ChannelFuture future = channel().write( + newCreateStreamCommand(grpcHeaders, streamTransportState)); + // Read a GOAWAY that indicates our stream can't be sent + channelRead(goAwayFrame(0, 0 /* NO_ERROR */, Unpooled.copiedBuffer("this is a test", UTF_8))); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); + verify(streamListener).closed(captor.capture(), same(REFUSED), + ArgumentMatchers.notNull()); + assertEquals(Status.UNAVAILABLE.getCode(), captor.getValue().getCode()); + assertEquals( + "Abrupt GOAWAY closed sent stream. 
HTTP/2 error code: NO_ERROR, " + + "debug data: this is a test", + captor.getValue().getDescription()); + assertTrue(future.isDone()); + } + @Test public void receivedGoAway_shouldFailBufferedStreamsExceedingMaxConcurrentStreams() throws Exception { @@ -704,7 +766,7 @@ public void exhaustedStreamsShouldFail() throws Exception { public void nonExistentStream() throws Exception { Status status = Status.INTERNAL.withDescription("zz"); - lifecycleManager.notifyShutdown(status); + lifecycleManager.notifyShutdown(status, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // Stream creation can race with the transport shutting down, with the create command already // enqueued. ChannelFuture future1 = createStream(); @@ -770,9 +832,7 @@ public void ping_failsWhenChannelCloses() throws Exception { handler().channelInactive(ctx()); // ping failed on channel going inactive assertEquals(1, callback.invocationCount); - assertTrue(callback.failureCause instanceof StatusException); - assertEquals(Status.Code.UNAVAILABLE, - ((StatusException) callback.failureCause).getStatus().getCode()); + assertEquals(Status.Code.UNAVAILABLE, callback.failureCause.getCode()); // A failed ping is still counted assertEquals(1, transportTracer.getStats().keepAlivesSent); } @@ -885,6 +945,159 @@ public void exceptionCaughtShouldCloseConnection() throws Exception { assertFalse(channel().isOpen()); } + @Test + public void missingAuthorityHeader_streamCreationShouldFail() throws Exception { + Http2Headers grpcHeadersWithoutAuthority = new DefaultHttp2Headers() + .scheme(HTTPS) + .path(as("/fakemethod")) + .method(HTTP_METHOD) + .add(as("auth"), as("sometoken")) + .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC) + .add(TE_HEADER, TE_TRAILERS); + ChannelFuture channelFuture = enqueue(newCreateStreamCommand( + grpcHeadersWithoutAuthority, streamTransportState)); + try { + channelFuture.get(); + fail("Expected stream creation failure"); + } catch (ExecutionException e) { + 
assertThat(e.getCause().getMessage()).isEqualTo("UNAVAILABLE: Missing authority header"); + } + } + + @Test + public void missingAuthorityVerifierInAttributes_streamCreationShouldFail() throws Exception { + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + StreamListener.MessageProducer producer = + (StreamListener.MessageProducer) invocation.getArguments()[0]; + InputStream message; + while ((message = producer.next()) != null) { + streamListenerMessageQueue.add(message); + } + return null; + } + }) + .when(streamListener) + .messagesAvailable(ArgumentMatchers.any()); + doAnswer((attributes) -> Attributes.EMPTY) + .when(listener) + .filterTransport(ArgumentMatchers.any(Attributes.class)); + lifecycleManager = new ClientTransportLifecycleManager(listener); + // This mocks the keepalive manager only for there's in which we verify it. For other tests + // it'll be null which will be testing if we behave correctly when it's not present. + if (setKeepaliveManagerFor.contains(testNameRule.getMethodName())) { + mockKeepAliveManager = mock(KeepAliveManager.class); + } + + initChannel(new GrpcHttp2ClientHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE)); + streamTransportState = new TransportStateImpl( + handler(), + channel().eventLoop(), + DEFAULT_MAX_MESSAGE_SIZE, + transportTracer); + streamTransportState.setListener(streamListener); + + grpcHeaders = new DefaultHttp2Headers() + .scheme(HTTPS) + .authority(as("www.fake.com")) + .path(as("/fakemethod")) + .method(HTTP_METHOD) + .add(as("auth"), as("sometoken")) + .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC) + .add(TE_HEADER, TE_TRAILERS); + + // Simulate receipt of initial remote settings. 
+ ByteBuf serializedSettings = serializeSettings(new Http2Settings()); + channelRead(serializedSettings); + channel().releaseOutbound(); + + ChannelFuture channelFuture = createStream(); + try { + channelFuture.get(); + fail("Expected stream creation failure"); + } catch (ExecutionException e) { + assertThat(e.getCause().getMessage()).isEqualTo( + "UNAVAILABLE: Authority verifier not found to verify authority"); + } + } + + @Test + public void authorityVerificationSuccess_streamCreationSucceeds() throws Exception { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + ChannelFuture channelFuture = createStream(); + channelFuture.get(); + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void authorityVerificationFailure_streamCreationFails() throws Exception { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + StreamListener.MessageProducer producer = + (StreamListener.MessageProducer) invocation.getArguments()[0]; + InputStream message; + while ((message = producer.next()) != null) { + streamListenerMessageQueue.add(message); + } + return null; + } + }) + .when(streamListener) + .messagesAvailable(ArgumentMatchers.any()); + doAnswer((attributes) -> Attributes.newBuilder().set( + GrpcAttributes.ATTR_AUTHORITY_VERIFIER, + (authority) -> Status.UNAVAILABLE.withCause( + new CertificateException("Peer verification failed"))).build()) + .when(listener) + .filterTransport(ArgumentMatchers.any(Attributes.class)); + lifecycleManager = new ClientTransportLifecycleManager(listener); + // This mocks the keepalive manager only for there's in which we verify it. For other tests + // it'll be null which will be testing if we behave correctly when it's not present. 
+ if (setKeepaliveManagerFor.contains(testNameRule.getMethodName())) { + mockKeepAliveManager = mock(KeepAliveManager.class); + } + + initChannel(new GrpcHttp2ClientHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE)); + streamTransportState = new TransportStateImpl( + handler(), + channel().eventLoop(), + DEFAULT_MAX_MESSAGE_SIZE, + transportTracer); + streamTransportState.setListener(streamListener); + + grpcHeaders = new DefaultHttp2Headers() + .scheme(HTTPS) + .authority(as("www.fake.com")) + .path(as("/fakemethod")) + .method(HTTP_METHOD) + .add(as("auth"), as("sometoken")) + .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC) + .add(TE_HEADER, TE_TRAILERS); + + // Simulate receipt of initial remote settings. + ByteBuf serializedSettings = serializeSettings(new Http2Settings()); + channelRead(serializedSettings); + channel().releaseOutbound(); + + ChannelFuture channelFuture = createStream(); + try { + channelFuture.get(); + fail("Expected stream creation failure"); + } catch (ExecutionException e) { + assertThat(e.getMessage()).isEqualTo("io.grpc.InternalStatusRuntimeException: UNAVAILABLE"); + } + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + @Override protected void makeStream() throws Exception { createStream(); @@ -946,13 +1159,15 @@ public Stopwatch get() { false, flowControlWindow, maxHeaderListSize, + softLimitHeaderListSize, stopwatchSupplier, tooManyPingsRunnable, transportTracer, Attributes.EMPTY, "someauthority", null, - fakeClock().getTicker()); + fakeClock().getTicker(), + new MetricRecorder() {}); } @Override @@ -973,7 +1188,7 @@ private static CreateStreamCommand newCreateStreamCommand( private static class PingCallbackImpl implements ClientTransport.PingCallback { int invocationCount; long roundTripTime; - Throwable failureCause; + Status failureCause; @Override public void onSuccess(long roundTripTimeNanos) { @@ -982,7 +1197,7 @@ public void onSuccess(long roundTripTimeNanos) { } @Override - public void 
onFailure(Throwable cause) { + public void onFailure(Status cause) { invocationCount++; this.failureCause = cause; } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java index 2a5a0df279a..4dd24c3fd4d 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java @@ -46,6 +46,7 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.Iterables; import com.google.common.io.BaseEncoding; import io.grpc.CallOptions; import io.grpc.InternalStatus; @@ -232,18 +233,20 @@ public void writeFrameFutureFailedShouldCancelRpc() { // Verify that failed SendGrpcFrameCommand results in immediate CancelClientStreamCommand. inOrder.verify(writeQueue).enqueue(any(CancelClientStreamCommand.class), eq(true)); // Verify that any other failures do not produce another CancelClientStreamCommand in the queue. - inOrder.verify(writeQueue, atLeast(1)).enqueue(any(SendGrpcFrameCommand.class), eq(false)); + inOrder.verify(writeQueue, atLeast(0)).enqueue(any(SendGrpcFrameCommand.class), eq(false)); inOrder.verify(writeQueue).enqueue(any(SendGrpcFrameCommand.class), eq(true)); inOrder.verifyNoMoreInteractions(); // Get the CancelClientStreamCommand written to the queue. Above we verified that there is // only one CancelClientStreamCommand enqueued, and is the third enqueued command (create, // frame write failure, cancel). 
- CancelClientStreamCommand cancelCommand = Mockito.mockingDetails(writeQueue).getInvocations() - // Get enqueue() innovations only - .stream().filter(invocation -> invocation.getMethod().getName().equals("enqueue")) + CancelClientStreamCommand cancelCommand = Iterables.get( + Iterables.filter( + Mockito.mockingDetails(writeQueue).getInvocations(), + // Get enqueue() innovations only + invocation -> invocation.getMethod().getName().equals("enqueue")), // Get the third invocation of enqueue() - .skip(2).findFirst().get() + 2) // Get the first argument (QueuedCommand command) .getArgument(0); diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 9777bb0926c..7023acc947c 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -37,12 +37,16 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.base.Optional; import com.google.common.base.Strings; import com.google.common.base.Ticker; import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ChannelLogger; @@ -52,13 +56,16 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.Marshaller; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StatusException; +import io.grpc.TlsChannelCredentials; import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientTransport; 
+import io.grpc.internal.DisconnectError; import io.grpc.internal.FakeClock; import io.grpc.internal.FixedObjectPool; import io.grpc.internal.GrpcUtil; @@ -73,6 +80,7 @@ import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; import io.grpc.netty.NettyTestUtil.TrackingObjectPoolForTest; import io.grpc.testing.TlsTesting; +import io.grpc.util.CertificateUtils; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelConfig; @@ -94,12 +102,18 @@ import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; +import io.netty.util.ReferenceCountUtil; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.InvocationTargetException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -113,6 +127,12 @@ import javax.annotation.Nullable; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import javax.security.auth.x500.X500Principal; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -127,11 +147,15 @@ * Tests for {@link NettyClientTransport}. 
*/ @RunWith(JUnit4.class) +@IgnoreJRERequirement public class NettyClientTransportTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); private static final SslContext SSL_CONTEXT = createSslContext(); + @SuppressWarnings("InlineMeInliner") // Requires Java 11 + private static final String LONG_STRING_OF_A = Strings.repeat("a", 128); + @Mock private ManagedClientTransport.Listener clientTransportListener; @@ -186,6 +210,7 @@ public void addDefaultUserAgent() throws Exception { startServer(); NettyClientTransport transport = newTransport(newNegotiator()); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); // Send a single RPC and wait for the response. new Rpc(transport).halfClose().waitForResponse(); @@ -198,18 +223,37 @@ public void addDefaultUserAgent() throws Exception { } @Test - public void setSoLingerChannelOption() throws IOException { + public void setSoLingerChannelOption() throws IOException, GeneralSecurityException { startServer(); Map, Object> channelOptions = new HashMap<>(); // set SO_LINGER option int soLinger = 123; channelOptions.put(ChannelOption.SO_LINGER, soLinger); NettyClientTransport transport = new NettyClientTransport( - address, new ReflectiveChannelFactory<>(NioSocketChannel.class), channelOptions, group, - newNegotiator(), false, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, - GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1L, false, authority, - null /* user agent */, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, - new SocketPicker(), new FakeChannelLogger(), false, Ticker.systemTicker()); + address, + new ReflectiveChannelFactory<>(NioSocketChannel.class), + channelOptions, + group, + newNegotiator(), + false, + DEFAULT_WINDOW_SIZE, + DEFAULT_MAX_MESSAGE_SIZE, + GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, + GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, + KEEPALIVE_TIME_NANOS_DISABLED, + 1L, + false, + authority, + null /* 
user agent */, + tooManyPingsRunnable, + new TransportTracer(), + Attributes.EMPTY, + new SocketPicker(), + new FakeChannelLogger(), + false, + new MetricRecorder() { + }, + Ticker.systemTicker()); transports.add(transport); callMeMaybe(transport.start(clientTransportListener)); @@ -225,6 +269,7 @@ public void overrideDefaultUserAgent() throws Exception { NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); new Rpc(transport, new Metadata()).halfClose().waitForResponse(); @@ -242,6 +287,7 @@ public void maxMessageSizeShouldBeEnforced() throws Throwable { NettyClientTransport transport = newTransport(newNegotiator(), 1, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null, true); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); try { // Send a single RPC and wait for the response. @@ -268,6 +314,7 @@ public void creatingMultipleTlsTransportsShouldSucceed() throws Exception { NettyClientTransport transport = newTransport(negotiator); callMeMaybe(transport.start(clientTransportListener)); } + verify(clientTransportListener, timeout(5000).times(2)).transportReady(); // Send a single RPC on each transport. 
final List rpcs = new ArrayList<>(transports.size()); @@ -297,6 +344,7 @@ public void run() { failureStatus.asRuntimeException()); } }); + verify(clientTransportListener, timeout(5000)).transportTerminated(); Rpc rpc = new Rpc(transport).halfClose(); try { @@ -327,9 +375,10 @@ public void tlsNegotiationFailurePropagatesToStatus() throws Exception { .trustManager(caCert) .keyManager(clientCert, clientKey) .build(); - ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext); + ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, null); final NettyClientTransport transport = newTransport(negotiator); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportTerminated(); Rpc rpc = new Rpc(transport).halfClose(); try { @@ -359,6 +408,7 @@ public void channelExceptionDuringNegotiatonPropagatesToStatus() throws Exceptio callMeMaybe(transport.start(clientTransportListener)); final Status failureStatus = Status.UNAVAILABLE.withDescription("oh noes!"); transport.channel().pipeline().fireExceptionCaught(failureStatus.asRuntimeException()); + verify(clientTransportListener, timeout(5000)).transportTerminated(); Rpc rpc = new Rpc(transport).halfClose(); try { @@ -390,6 +440,7 @@ public void run() { } } }); + verify(clientTransportListener, timeout(5000)).transportTerminated(); Rpc rpc = new Rpc(transport).halfClose(); try { @@ -409,6 +460,7 @@ public void bufferedStreamsShouldBeClosedWhenConnectionTerminates() throws Excep NettyClientTransport transport = newTransport(newNegotiator()); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); // Send a dummy RPC in order to ensure that the updated SETTINGS_MAX_CONCURRENT_STREAMS // has been received by the remote endpoint. 
@@ -454,12 +506,30 @@ public void failingToConstructChannelShouldFailGracefully() throws Exception { address = TestUtils.testServerAddress(new InetSocketAddress(12345)); authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort()); NettyClientTransport transport = new NettyClientTransport( - address, new ReflectiveChannelFactory<>(CantConstructChannel.class), - new HashMap, Object>(), group, - newNegotiator(), false, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, - GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1, false, authority, - null, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker(), - new FakeChannelLogger(), false, Ticker.systemTicker()); + address, + new ReflectiveChannelFactory<>(CantConstructChannel.class), + new HashMap, Object>(), + group, + newNegotiator(), + false, + DEFAULT_WINDOW_SIZE, + DEFAULT_MAX_MESSAGE_SIZE, + GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, + GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, + KEEPALIVE_TIME_NANOS_DISABLED, + 1, + false, + authority, + null, + tooManyPingsRunnable, + new TransportTracer(), + Attributes.EMPTY, + new SocketPicker(), + new FakeChannelLogger(), + false, + new MetricRecorder() { + }, + Ticker.systemTicker()); transports.add(transport); // Should not throw @@ -485,8 +555,8 @@ public void onSuccess(long roundTripTimeNanos) { } @Override - public void onFailure(Throwable cause) { - pingResult.setException(cause); + public void onFailure(Status cause) { + pingResult.setException(cause.asException()); } }; transport.ping(pingCallback, clock.getScheduledExecutorService()); @@ -543,6 +613,7 @@ public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception { NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, 1, null, true); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); try { // Send a single RPC and wait for the response. 
@@ -560,9 +631,6 @@ public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception { @Test public void huffmanCodingShouldNotBePerformed() throws Exception { - @SuppressWarnings("InlineMeInliner") // Requires Java 11 - String longStringOfA = Strings.repeat("a", 128); - negotiator = ProtocolNegotiators.serverPlaintext(); startServer(); @@ -573,9 +641,10 @@ public void huffmanCodingShouldNotBePerformed() throws Exception { Metadata headers = new Metadata(); headers.put(Metadata.Key.of("test", Metadata.ASCII_STRING_MARSHALLER), - longStringOfA); + LONG_STRING_OF_A); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); AtomicBoolean foundExpectedHeaderBytes = new AtomicBoolean(false); @@ -584,7 +653,7 @@ public void huffmanCodingShouldNotBePerformed() throws Exception { public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { if (msg instanceof ByteBuf) { - if (((ByteBuf) msg).toString(StandardCharsets.UTF_8).contains(longStringOfA)) { + if (((ByteBuf) msg).toString(StandardCharsets.UTF_8).contains(LONG_STRING_OF_A)) { foundExpectedHeaderBytes.set(true); } } @@ -599,12 +668,54 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) } } + @Test + public void huffmanCodingShouldNotBePerformedOnServer() throws Exception { + negotiator = ProtocolNegotiators.serverPlaintext(); + + Metadata responseHeaders = new Metadata(); + responseHeaders.put(Metadata.Key.of("test", Metadata.ASCII_STRING_MARSHALLER), + LONG_STRING_OF_A); + + startServer(new EchoServerListener(responseHeaders)); + + NettyClientTransport transport = newTransport(ProtocolNegotiators.plaintext(), + DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null, false, + TimeUnit.SECONDS.toNanos(10L), TimeUnit.SECONDS.toNanos(1L), + new ReflectiveChannelFactory<>(NioSocketChannel.class), group); + + callMeMaybe(transport.start(clientTransportListener)); + 
verify(clientTransportListener, timeout(5000)).transportReady(); + + AtomicBoolean foundExpectedHeaderBytes = new AtomicBoolean(false); + + // Add a handler to the client pipeline to inspect server's response + transport.channel().pipeline().addFirst(new ChannelDuplexHandler() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof ByteBuf) { + String data = ((ByteBuf) msg).toString(StandardCharsets.UTF_8); + if (data.contains(LONG_STRING_OF_A)) { + foundExpectedHeaderBytes.set(true); + } + } + super.channelRead(ctx, msg); + } + }); + + new Rpc(transport).halfClose().waitForResponse(); + + if (!foundExpectedHeaderBytes.get()) { + fail("expected to find UTF-8 encoded 'a's in the response header sent by the server"); + } + } + @Test public void maxHeaderListSizeShouldBeEnforcedOnServer() throws Exception { startServer(100, 1); NettyClientTransport transport = newTransport(newNegotiator()); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); try { // Send a single RPC and wait for the response. 
@@ -649,6 +760,7 @@ public void clientStreamGetsAttributes() throws Exception { startServer(); NettyClientTransport transport = newTransport(newNegotiator()); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); Rpc rpc = new Rpc(transport).halfClose(); rpc.waitForResponse(); @@ -667,6 +779,7 @@ public void keepAliveEnabled() throws Exception { NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null /* user agent */, true /* keep alive */); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); Rpc rpc = new Rpc(transport).halfClose(); rpc.waitForResponse(); @@ -679,6 +792,7 @@ public void keepAliveDisabled() throws Exception { NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null /* user agent */, false /* keep alive */); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); Rpc rpc = new Rpc(transport).halfClose(); rpc.waitForResponse(); @@ -766,11 +880,13 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception { .trustManager(caCert) .keyManager(clientCert, clientKey) .build(); - ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool); + ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool, + Optional.absent(), null, null); // after starting the client, the Executor in the client pool should be used assertEquals(true, clientExecutorPool.isInUse()); final NettyClientTransport transport = newTransport(negotiator); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); Rpc rpc = new Rpc(transport).halfClose(); rpc.waitForResponse(); // closing the 
negotiators should return the executors back to pool, and release the resource @@ -780,6 +896,179 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception { assertEquals(false, serverExecutorPool.isInUse()); } + /** + * This test tests the case of TlsCredentials passed to ProtocolNegotiators not having an instance + * of X509ExtendedTrustManager (this is not testable in ProtocolNegotiatorsTest without creating + * accessors for the internal state of negotiator whether it has a X509ExtendedTrustManager, + * hence the need to test it in this class instead). To establish a successful handshake we create + * a fake X509TrustManager not implementing X509ExtendedTrustManager but wraps the real + * X509ExtendedTrustManager. + */ + @Test + public void authorityOverrideInCallOptions_noX509ExtendedTrustManager_newStreamCreationFails() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + InputStream caCert = TlsTesting.loadCert("ca.pem"); + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert); + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)).build()); + NettyClientTransport transport = newTransport(result.negotiator.newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + Rpc rpc = new Rpc(transport, new Metadata(), "foo.test.google.in"); + try { + rpc.waitForClose(); + fail("Expected exception in starting stream"); + } catch (ExecutionException ex) { 
+ Status status = ((StatusException) ex.getCause()).getStatus(); + assertThat(status.getDescription()).isEqualTo("Can't allow authority override in rpc " + + "when X509ExtendedTrustManager is not available"); + assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE); + } + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void authorityOverrideInCallOptions_doesntMatchServerPeerHost_newStreamCreationFails() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + Rpc rpc = new Rpc(transport, new Metadata(), "foo.test.google.in"); + try { + rpc.waitForClose(); + fail("Expected exception in starting stream"); + } catch (ExecutionException ex) { + Status status = ((StatusException) ex.getCause()).getStatus(); + assertThat(status.getDescription()).isEqualTo("Peer hostname verification during rpc " + + "failed for authority 'foo.test.google.in'"); + assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException()) + .isInstanceOf(CertificateException.class); + assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException() + .getMessage()).isEqualTo( + "No subject alternative DNS name matching foo.test.google.in found."); + } + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void 
authorityOverrideInCallOptions_matchesServerPeerHost_newStreamCreationSucceeds() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + new Rpc(transport, new Metadata(), "foo.test.google.fr").waitForResponse(); + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + // Without removing the port number part that {@link X509AuthorityVerifier} does, there will be a + // java.security.cert.CertificateException: Illegal given domain name: foo.test.google.fr:12345 + @Test + public void authorityOverrideInCallOptions_portNumberInAuthority_isStrippedForPeerVerification() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + new Rpc(transport, new Metadata(), "foo.test.google.fr:12345").waitForResponse(); + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void authorityOverrideInCallOptions_portNumberAndIpv6_isStrippedForPeerVerification() + throws 
IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + new Rpc(transport, new Metadata(), "[2001:db8:3333:4444:5555:6666:1.2.3.4]:12345") + .waitForResponse(); + } catch (ExecutionException ex) { + Status status = ((StatusException) ex.getCause()).getStatus(); + assertThat(status.getDescription()).isEqualTo("Peer hostname verification during rpc " + + "failed for authority '[2001:db8:3333:4444:5555:6666:1.2.3.4]:12345'"); + assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException()) + .isInstanceOf(CertificateException.class); + // Port number is removed by {@link X509AuthorityVerifier}. 
+ assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException() + .getMessage()).isEqualTo( + "No subject alternative names matching IP address 2001:db8:3333:4444:5555:6666:1.2.3.4 " + + "found"); + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void authorityOverrideInCallOptions_notMatches_flagDisabled_createsStream() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + new Rpc(transport, new Metadata(), "foo.test.google.in").waitForResponse(); + } + private Throwable getRootCause(Throwable t) { if (t.getCause() == null) { return t; @@ -787,10 +1076,37 @@ private Throwable getRootCause(Throwable t) { return getRootCause(t.getCause()); } - private ProtocolNegotiator newNegotiator() throws IOException { + private ProtocolNegotiator newNegotiator() throws IOException, GeneralSecurityException { InputStream caCert = TlsTesting.loadCert("ca.pem"); SslContext clientContext = GrpcSslContexts.forClient().trustManager(caCert).build(); - return ProtocolNegotiators.tls(clientContext); + return ProtocolNegotiators.tls(clientContext, + (X509TrustManager) getX509ExtendedTrustManager(TlsTesting.loadCert("ca.pem"))); + } + + private static TrustManager getX509ExtendedTrustManager(InputStream rootCerts) + throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. 
+ throw new GeneralSecurityException(ex); + } + X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts); + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + for (TrustManager trustManager : trustManagerFactory.getTrustManagers()) { + if (trustManager instanceof X509ExtendedTrustManager) { + return trustManager; + } + } + return null; } private NettyClientTransport newTransport(ProtocolNegotiator negotiator) { @@ -813,11 +1129,29 @@ private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int max keepAliveTimeNano = KEEPALIVE_TIME_NANOS_DISABLED; } NettyClientTransport transport = new NettyClientTransport( - address, channelFactory, new HashMap, Object>(), group, - negotiator, false, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, - keepAliveTimeNano, keepAliveTimeoutNano, - false, authority, userAgent, tooManyPingsRunnable, - new TransportTracer(), eagAttributes, new SocketPicker(), new FakeChannelLogger(), false, + address, + channelFactory, + new HashMap, Object>(), + group, + negotiator, + false, + DEFAULT_WINDOW_SIZE, + maxMsgSize, + maxHeaderListSize, + maxHeaderListSize, + keepAliveTimeNano, + keepAliveTimeoutNano, + false, + authority, + userAgent, + tooManyPingsRunnable, + new TransportTracer(), + eagAttributes, + new SocketPicker(), + new FakeChannelLogger(), + false, + new MetricRecorder() { + }, Ticker.systemTicker()); transports.add(transport); return transport; @@ -827,23 +1161,45 @@ private void startServer() throws IOException { startServer(100, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); } + private void startServer(ServerListener serverListener) throws IOException { + startServer(100, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, serverListener); + } + 
private void startServer(int maxStreamsPerConnection, int maxHeaderListSize) throws IOException { + startServer(maxStreamsPerConnection, maxHeaderListSize, serverListener); + } + + private void startServer(int maxStreamsPerConnection, int maxHeaderListSize, + ServerListener serverListener) throws IOException { server = new NettyServer( TestUtils.testServerAddresses(new InetSocketAddress(0)), new ReflectiveChannelFactory<>(NioServerSocketChannel.class), new HashMap, Object>(), new HashMap, Object>(), - new FixedObjectPool<>(group), new FixedObjectPool<>(group), false, negotiator, + new FixedObjectPool<>(group), + new FixedObjectPool<>(group), + false, + negotiator, Collections.emptyList(), TransportTracer.getDefaultFactory(), maxStreamsPerConnection, false, - DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, maxHeaderListSize, - DEFAULT_SERVER_KEEPALIVE_TIME_NANOS, DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS, + DEFAULT_WINDOW_SIZE, + DEFAULT_MAX_MESSAGE_SIZE, + maxHeaderListSize, + maxHeaderListSize, + DEFAULT_SERVER_KEEPALIVE_TIME_NANOS, + DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS, MAX_CONNECTION_IDLE_NANOS_DISABLED, - MAX_CONNECTION_AGE_NANOS_DISABLED, MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE, true, 0, - MAX_RST_COUNT_DISABLED, 0, Attributes.EMPTY, - channelz); + MAX_CONNECTION_AGE_NANOS_DISABLED, + MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE, + true, + 0, + MAX_RST_COUNT_DISABLED, + 0, + Attributes.EMPTY, + channelz, + new MetricRecorder() {}); server.start(serverListener); address = TestUtils.testServerAddress((InetSocketAddress) server.getListenSocketAddress()); authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort()); @@ -879,13 +1235,20 @@ private static class Rpc { final TestClientStreamListener listener = new TestClientStreamListener(); Rpc(NettyClientTransport transport) { - this(transport, new Metadata()); + this(transport, new Metadata(), null); } Rpc(NettyClientTransport transport, Metadata headers) { + this(transport, headers, null); + 
} + + Rpc(NettyClientTransport transport, Metadata headers, String authorityOverride) { stream = transport.newStream( METHOD, headers, CallOptions.DEFAULT, new ClientStreamTracer[]{ new ClientStreamTracer() {} }); + if (authorityOverride != null) { + stream.setAuthority(authorityOverride); + } stream.start(listener); stream.request(1); stream.writeMessage(new ByteArrayInputStream(MESSAGE.getBytes(UTF_8))); @@ -975,6 +1338,15 @@ private final class EchoServerListener implements ServerListener { final List transports = new ArrayList<>(); final List streamListeners = Collections.synchronizedList(new ArrayList()); + Metadata responseHeaders; + + public EchoServerListener() { + this(new Metadata()); + } + + public EchoServerListener(Metadata responseHeaders) { + this.responseHeaders = responseHeaders; + } @Override public ServerTransportListener transportCreated(final ServerTransport transport) { @@ -984,7 +1356,7 @@ public ServerTransportListener transportCreated(final ServerTransport transport) public void streamCreated(ServerStream stream, String method, Metadata headers) { EchoServerStreamListener listener = new EchoServerStreamListener(stream, headers); stream.setListener(listener); - stream.writeHeaders(new Metadata(), true); + stream.writeHeaders(responseHeaders, true); stream.request(1); streamListeners.add(listener); } @@ -1031,9 +1403,15 @@ public NoopHandler(GrpcHttp2ConnectionHandler grpcHandler) { this.grpcHandler = grpcHandler; } + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + // Prevent any data being passed to NettyClientHandler + ReferenceCountUtil.release(msg); + } + @Override public void channelRegistered(ChannelHandlerContext ctx) throws Exception { - ctx.pipeline().addBefore(ctx.name(), null, grpcHandler); + ctx.pipeline().addAfter(ctx.name(), null, grpcHandler); } public void fail(ChannelHandlerContext ctx, Throwable cause) { @@ -1077,4 +1455,62 @@ public void log(ChannelLogLevel level, String message) {} @Override 
public void log(ChannelLogLevel level, String messageFormat, Object... args) {} } + + static class FakeClientTransportListener implements ManagedClientTransport.Listener { + private final SettableFuture connected; + + @GuardedBy("this") + private boolean isConnected = false; + + public FakeClientTransportListener(SettableFuture connected) { + this.connected = connected; + } + + @Override + public void transportShutdown(Status s, DisconnectError e) {} + + @Override + public void transportTerminated() {} + + @Override + public void transportReady() { + synchronized (this) { + isConnected = true; + } + connected.set(null); + } + + synchronized boolean isConnected() { + return isConnected; + } + + @Override + public void transportInUse(boolean inUse) {} + } + + private static class FakeTrustManager implements X509TrustManager { + + private final X509TrustManager delegate; + + public FakeTrustManager(X509TrustManager x509ExtendedTrustManager) { + this.delegate = x509ExtendedTrustManager; + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkClientTrusted(x509Certificates, s); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkServerTrusted(x509Certificates, s); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return delegate.getAcceptedIssuers(); + } + } } diff --git a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java index eef8d30e05a..c971294fbb6 100644 --- a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java @@ -38,7 +38,6 @@ import io.grpc.internal.WritableBuffer; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufUtil; import io.netty.buffer.CompositeByteBuf; import 
io.netty.buffer.Unpooled; import io.netty.buffer.UnpooledByteBufAllocator; @@ -68,6 +67,7 @@ import java.nio.ByteBuffer; import java.util.concurrent.Delayed; import java.util.concurrent.TimeUnit; +import org.junit.After; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -84,7 +84,6 @@ public abstract class NettyHandlerTestBase { protected static final int STREAM_ID = 3; - private ByteBuf content; private EmbeddedChannel channel; @@ -106,18 +105,24 @@ protected void manualSetUp() throws Exception {} protected final TransportTracer transportTracer = new TransportTracer(); protected int flowControlWindow = DEFAULT_WINDOW_SIZE; protected boolean autoFlowControl = false; - private final FakeClock fakeClock = new FakeClock(); FakeClock fakeClock() { return fakeClock; } + @After + public void tearDown() throws Exception { + if (channel() != null) { + channel().releaseInbound(); + channel().releaseOutbound(); + } + } + /** * Must be called by subclasses to initialize the handler and channel. 
*/ protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception { - content = Unpooled.copiedBuffer("hello world", UTF_8); frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter())); frameReader = new DefaultHttp2FrameReader(headersDecoder); @@ -233,11 +238,11 @@ protected final Http2FrameReader frameReader() { } protected final ByteBuf content() { - return content; + return Unpooled.copiedBuffer(contentAsArray()); } protected final byte[] contentAsArray() { - return ByteBufUtil.getBytes(content()); + return "\000\000\000\000\rhello world".getBytes(UTF_8); } protected final Http2FrameWriter verifyWrite() { @@ -252,8 +257,8 @@ protected final void channelRead(Object obj) throws Exception { channel.writeInbound(obj); } - protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) { - final ByteBuf compressionFrame = Unpooled.buffer(content.length); + protected ByteBuf grpcFrame(byte[] message) { + final ByteBuf compressionFrame = Unpooled.buffer(message.length); MessageFramer framer = new MessageFramer( new MessageFramer.Sink() { @Override @@ -262,23 +267,22 @@ public void deliverFrame( if (frame != null) { ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf(); compressionFrame.writeBytes(bytebuf); + bytebuf.release(); } } }, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT), StatsTraceContext.NOOP); - framer.writePayload(new ByteArrayInputStream(content)); - framer.flush(); - ChannelHandlerContext ctx = newMockContext(); - new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream, - newPromise()); - return captureWrite(ctx); + framer.writePayload(new ByteArrayInputStream(message)); + framer.close(); + return compressionFrame; } - protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) { - // Need to retain the content since the frameWriter releases it. 
- content.retain(); + protected final ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) { + return dataFrame(streamId, endStream, grpcFrame(content)); + } + protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) { ChannelHandlerContext ctx = newMockContext(); new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise()); return captureWrite(ctx); @@ -410,6 +414,7 @@ public void dataSizeSincePingAccumulates() throws Exception { channelRead(dataFrame(3, false, buff.copy())); assertEquals(length * 3, handler.flowControlPing().getDataSincePing()); + buff.release(); } @Test @@ -608,12 +613,14 @@ public void bdpPingWindowResizing() throws Exception { private void readPingAck(long pingData) throws Exception { channelRead(pingFrame(true, pingData)); + channel().releaseOutbound(); } private void readXCopies(int copies, byte[] data) throws Exception { for (int i = 0; i < copies; i++) { channelRead(grpcDataFrame(STREAM_ID, false, data)); // buffer it stream().request(1); // consume it + channel().releaseOutbound(); } } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerBuilderTest.java b/netty/src/test/java/io/grpc/netty/NettyServerBuilderTest.java index 6d8192322aa..f3b73a515b5 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerBuilderTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerBuilderTest.java @@ -16,20 +16,19 @@ package io.grpc.netty; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; -import com.google.common.truth.Truth; -import io.grpc.ServerStreamTracer; +import io.grpc.MetricRecorder; import io.netty.channel.EventLoopGroup; import io.netty.channel.local.LocalServerChannel; import io.netty.handler.ssl.SslContext; import java.net.InetSocketAddress; import 
java.util.concurrent.TimeUnit; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -39,18 +38,16 @@ @RunWith(JUnit4.class) public class NettyServerBuilderTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); - private NettyServerBuilder builder = NettyServerBuilder.forPort(8080); @Test public void addMultipleListenAddresses() { builder.addListenAddress(new InetSocketAddress(8081)); - NettyServer server = - builder.buildTransportServers(ImmutableList.of()); + NettyServer server = builder.buildTransportServers( + ImmutableList.of(), + new MetricRecorder() {}); - Truth.assertThat(server.getListenSocketAddresses()).hasSize(2); + assertThat(server.getListenSocketAddresses()).hasSize(2); } @Test @@ -63,105 +60,112 @@ public void failIfSslContextIsNotServer() { SslContext sslContext = mock(SslContext.class); when(sslContext.isClient()).thenReturn(true); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Client SSL context can not be used for server"); - builder.sslContext(sslContext); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, () -> builder.sslContext(sslContext)); + assertThat(e).hasMessageThat().isEqualTo("Client SSL context can not be used for server"); } @Test public void failIfKeepAliveTimeNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("keepalive time must be positive"); - - builder.keepAliveTime(-10L, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.keepAliveTime(-10L, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("keepalive time must be positive:-10"); } @Test public void failIfKeepAliveTimeoutNegative() { - thrown.expect(IllegalArgumentException.class); - 
thrown.expectMessage("keepalive timeout must be positive"); - - builder.keepAliveTimeout(-10L, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.keepAliveTimeout(-10L, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("keepalive timeout must be positive: -10"); } @Test public void failIfMaxConcurrentCallsPerConnectionNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("max must be positive"); - - builder.maxConcurrentCallsPerConnection(0); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxConcurrentCallsPerConnection(0)); + assertThat(e).hasMessageThat().isEqualTo("max must be positive: 0"); } @Test public void failIfMaxInboundMetadataSizeNonPositive() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("maxInboundMetadataSize must be positive"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxInboundMetadataSize(0)); + assertThat(e).hasMessageThat().isEqualTo("maxInboundMetadataSize must be positive: 0"); + } - builder.maxInboundMetadataSize(0); + @Test + public void failIfSoftInboundMetadataSizeNonPositive() { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxInboundMetadataSize(0, 100)); + assertThat(e).hasMessageThat().isEqualTo("softLimitHeaderListSize must be positive: 0"); } @Test - public void failIfMaxConnectionIdleNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("max connection idle must be positive"); + public void failIfMaxInboundMetadataSizeSmallerThanSoft() { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxInboundMetadataSize(100, 80)); + assertThat(e).hasMessageThat().isEqualTo("maxInboundMetadataSize: 80 " + + "must be greater than softLimitHeaderListSize: 100"); + } - builder.maxConnectionIdle(-1, 
TimeUnit.HOURS); + @Test + public void failIfMaxConnectionIdleNegative() { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxConnectionIdle(-1, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("max connection idle must be positive: -1"); } @Test public void failIfMaxConnectionAgeNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("max connection age must be positive"); - - builder.maxConnectionAge(-1, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxConnectionAge(-1, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("max connection age must be positive: -1"); } @Test public void failIfMaxConnectionAgeGraceNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("max connection age grace must be non-negative"); - - builder.maxConnectionAgeGrace(-1, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxConnectionAgeGrace(-1, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("max connection age grace must be non-negative: -1"); } @Test public void failIfPermitKeepAliveTimeNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("permit keepalive time must be non-negative"); - - builder.permitKeepAliveTime(-1, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.permitKeepAliveTime(-1, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("permit keepalive time must be non-negative: -1"); } @Test public void assertEventLoopsAndChannelType_onlyBossGroupProvided() { EventLoopGroup mockEventLoopGroup = mock(EventLoopGroup.class); builder.bossEventLoopGroup(mockEventLoopGroup); - thrown.expect(IllegalStateException.class); - thrown.expectMessage( - "All of BossEventLoopGroup, WorkerEventLoopGroup and ChannelType should be 
provided"); - - builder.assertEventLoopsAndChannelType(); + IllegalStateException e = assertThrows(IllegalStateException.class, + builder::assertEventLoopsAndChannelType); + assertThat(e).hasMessageThat().isEqualTo( + "All of BossEventLoopGroup, WorkerEventLoopGroup and ChannelType should be provided " + + "or neither should be"); } @Test public void assertEventLoopsAndChannelType_onlyWorkerGroupProvided() { EventLoopGroup mockEventLoopGroup = mock(EventLoopGroup.class); builder.workerEventLoopGroup(mockEventLoopGroup); - thrown.expect(IllegalStateException.class); - thrown.expectMessage( - "All of BossEventLoopGroup, WorkerEventLoopGroup and ChannelType should be provided"); - - builder.assertEventLoopsAndChannelType(); + IllegalStateException e = assertThrows(IllegalStateException.class, + builder::assertEventLoopsAndChannelType); + assertThat(e).hasMessageThat().isEqualTo( + "All of BossEventLoopGroup, WorkerEventLoopGroup and ChannelType should be provided " + + "or neither should be"); } @Test public void assertEventLoopsAndChannelType_onlyTypeProvided() { builder.channelType(LocalServerChannel.class); - thrown.expect(IllegalStateException.class); - thrown.expectMessage( - "All of BossEventLoopGroup, WorkerEventLoopGroup and ChannelType should be provided"); - - builder.assertEventLoopsAndChannelType(); + IllegalStateException e = assertThrows(IllegalStateException.class, + builder::assertEventLoopsAndChannelType); + assertThat(e).hasMessageThat().isEqualTo( + "All of BossEventLoopGroup, WorkerEventLoopGroup and ChannelType should be provided " + + "or neither should be"); } @Test @@ -186,4 +190,5 @@ public void useNioTransport_shouldNotThrow() { builder.assertEventLoopsAndChannelType(); } + } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 541490847c0..1c8d2b5479d 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ 
b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -43,6 +43,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; @@ -58,6 +59,7 @@ import io.grpc.Attributes; import io.grpc.InternalStatus; import io.grpc.Metadata; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.Status.Code; @@ -74,10 +76,10 @@ import io.grpc.internal.testing.TestServerStreamTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http2.DefaultHttp2Headers; import io.netty.handler.codec.http2.Http2CodecUtil; import io.netty.handler.codec.http2.Http2Error; @@ -120,27 +122,22 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase streamListenerMessageQueue = new LinkedList<>(); private int maxConcurrentStreams = Integer.MAX_VALUE; private int maxHeaderListSize = Integer.MAX_VALUE; + private int softLimitHeaderListSize = Integer.MAX_VALUE; private boolean permitKeepAliveWithoutCalls = true; private long permitKeepAliveTimeInNanos = 0; private long maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED; @@ -207,6 +204,19 @@ protected void manualSetUp() throws Exception { // Simulate receipt of initial remote settings. 
ByteBuf serializedSettings = serializeSettings(new Http2Settings()); channelRead(serializedSettings); + channel().releaseOutbound(); + } + + @Test + public void tcpMetrics_recorded() throws Exception { + manualSetUp(); + handler().channelActive(ctx()); + // Verify that channelActive triggered TcpMetrics + verify(metricRecorder, atLeastOnce()).addLongCounter( + eq(io.grpc.InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT), + eq(1L), + any(), + any()); } @Test @@ -228,10 +238,11 @@ public void sendFrameShouldSucceed() throws Exception { createStream(); // Send a frame and verify that it was written. + ByteBuf content = content(); ChannelFuture future = enqueue( - new SendGrpcFrameCommand(stream.transportState(), content(), false)); + new SendGrpcFrameCommand(stream.transportState(), content, false)); assertTrue(future.isSuccess()); - verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(false), + verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(false), any(ChannelPromise.class)); } @@ -266,10 +277,11 @@ private void inboundDataShouldForwardToStreamListener(boolean endStream) throws // Create a data frame and then trigger the handler to read it. 
ByteBuf frame = grpcDataFrame(STREAM_ID, endStream, contentAsArray()); channelRead(frame); + channel().releaseOutbound(); verify(streamListener, atLeastOnce()) .messagesAvailable(any(StreamListener.MessageProducer.class)); InputStream message = streamListenerMessageQueue.poll(); - assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(message)); + assertArrayEquals(contentAsArray(), ByteStreams.toByteArray(message)); message.close(); assertNull("no additional message expected", streamListenerMessageQueue.poll()); @@ -545,7 +557,8 @@ public void headersWithInvalidMethodShouldFail() throws Exception { .set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.INTERNAL.value())) .set(InternalStatus.MESSAGE_KEY.name(), "Method 'FAKE' is not supported") .status("" + 405) - .set(CONTENT_TYPE_HEADER, "text/plain; charset=utf-8"); + .set(CONTENT_TYPE_HEADER, "text/plain; charset=utf-8") + .set(HttpHeaderNames.ALLOW, HTTP_METHOD); verifyWrite() .writeHeaders( @@ -869,7 +882,7 @@ public void keepAliveEnforcer_sendingDataResetsCounters() throws Exception { future.get(); for (int i = 0; i < 10; i++) { future = enqueue( - new SendGrpcFrameCommand(stream.transportState(), content().retainedSlice(), false)); + new SendGrpcFrameCommand(stream.transportState(), content(), false)); future.get(); channel().releaseOutbound(); channelRead(pingFrame(false /* isAck */, 1L)); @@ -1292,6 +1305,7 @@ public void maxRstCount_withinLimit_succeeds() throws Exception { maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); manualSetUp(); rapidReset(maxRstCount); + assertTrue(channel().isOpen()); } @@ -1301,10 +1315,13 @@ public void maxRstCount_exceedsLimit_fails() throws Exception { maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); manualSetUp(); assertThrows(ClosedChannelException.class, () -> rapidReset(maxRstCount + 1)); + assertFalse(channel().isOpen()); } private void rapidReset(int burstSize) throws Exception { + 
when(streamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class))) + .thenAnswer((args) -> new TestServerStreamTracer()); Http2Headers headers = new DefaultHttp2Headers() .method(HTTP_METHOD) .set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8)) @@ -1324,6 +1341,48 @@ private void rapidReset(int burstSize) throws Exception { } } + @Test + public void maxRstCountSent_withinLimit_succeeds() throws Exception { + maxRstCount = 10; + maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); + manualSetUp(); + madeYouReset(maxRstCount); + + assertTrue(channel().isOpen()); + } + + @Test + public void maxRstCountSent_exceedsLimit_fails() throws Exception { + maxRstCount = 10; + maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); + manualSetUp(); + assertThrows(ClosedChannelException.class, () -> madeYouReset(maxRstCount + 1)); + + assertFalse(channel().isOpen()); + } + + private void madeYouReset(int burstSize) throws Exception { + when(streamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class))) + .thenAnswer((args) -> new TestServerStreamTracer()); + Http2Headers headers = new DefaultHttp2Headers() + .method(HTTP_METHOD) + .set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8)) + .set(TE_HEADER, TE_TRAILERS) + .path(new AsciiString("/foo/bar")); + int streamId = 1; + long rpcTimeNanos = maxRstPeriodNanos / 2 / burstSize; + for (int period = 0; period < 3; period++) { + for (int i = 0; i < burstSize; i++) { + channelRead(headersFrame(streamId, headers)); + channelRead(windowUpdate(streamId, 0)); + streamId += 2; + fakeClock().forwardNanos(rpcTimeNanos); + } + while (channel().readOutbound() != null) {} + fakeClock().forwardNanos(maxRstPeriodNanos - rpcTimeNanos * burstSize + 1); + } + } + private void createStream() throws Exception { Http2Headers headers = new DefaultHttp2Headers() .method(HTTP_METHOD) @@ -1343,11 +1402,7 @@ private void createStream() throws Exception { private ByteBuf emptyGrpcFrame(int 
streamId, boolean endStream) throws Exception { ByteBuf buf = NettyTestUtil.messageFrame(""); - try { - return dataFrame(streamId, endStream, buf); - } finally { - buf.release(); - } + return dataFrame(streamId, endStream, buf); } @Override @@ -1363,6 +1418,7 @@ protected NettyServerHandler newHandler() { autoFlowControl, flowControlWindow, maxHeaderListSize, + softLimitHeaderListSize, DEFAULT_MAX_MESSAGE_SIZE, keepAliveTimeInNanos, keepAliveTimeoutInNanos, @@ -1374,7 +1430,8 @@ protected NettyServerHandler newHandler() { maxRstCount, maxRstPeriodNanos, Attributes.EMPTY, - fakeClock().getTicker()); + fakeClock().getTicker(), + metricRecorder); } @Override diff --git a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java index 452f68341b1..2f2933ae103 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java @@ -38,7 +38,9 @@ import com.google.common.base.Strings; import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.Iterables; import com.google.common.collect.ListMultimap; +import com.google.common.collect.Lists; import io.grpc.Attributes; import io.grpc.Metadata; import io.grpc.Status; @@ -57,7 +59,6 @@ import java.util.LinkedList; import java.util.List; import java.util.Queue; -import java.util.stream.Collectors; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -157,7 +158,7 @@ public void writeFrameFutureFailedShouldCancelRpc() { // Verify that failed SendGrpcFrameCommand results in immediate CancelServerStreamCommand. inOrder.verify(writeQueue).enqueue(any(CancelServerStreamCommand.class), eq(true)); // Verify that any other failures do not produce another CancelServerStreamCommand in the queue. 
- inOrder.verify(writeQueue, atLeast(1)).enqueue(any(SendGrpcFrameCommand.class), eq(false)); + inOrder.verify(writeQueue, atLeast(0)).enqueue(any(SendGrpcFrameCommand.class), eq(false)); inOrder.verify(writeQueue).enqueue(any(SendGrpcFrameCommand.class), eq(true)); inOrder.verifyNoMoreInteractions(); } @@ -217,14 +218,15 @@ private CancelServerStreamCommand findCancelServerStreamCommand() { // Ensure there's no CancelServerStreamCommand enqueued with flush=false. verify(writeQueue, never()).enqueue(any(CancelServerStreamCommand.class), eq(false)); - List commands = Mockito.mockingDetails(writeQueue).getInvocations() - .stream() - // Get enqueue() innovations only. - .filter(invocation -> invocation.getMethod().getName().equals("enqueue")) - // Find the cancel commands. - .filter(invocation -> invocation.getArgument(0) instanceof CancelServerStreamCommand) - .map(invocation -> invocation.getArgument(0, CancelServerStreamCommand.class)) - .collect(Collectors.toList()); + List commands = Lists.newArrayList( + Iterables.transform( + Iterables.filter( + Mockito.mockingDetails(writeQueue).getInvocations(), + // Get enqueue() innovations only + invocation -> invocation.getMethod().getName().equals("enqueue") + // Find the cancel commands. 
+ && invocation.getArgument(0) instanceof CancelServerStreamCommand), + invocation -> invocation.getArgument(0, CancelServerStreamCommand.class))); assertWithMessage("Expected exactly one CancelClientStreamCommand").that(commands).hasSize(1); return commands.get(0); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerTest.java index 64d31070156..61c3f9e219e 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerTest.java @@ -37,6 +37,7 @@ import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalInstrumented; import io.grpc.Metadata; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.internal.FixedObjectPool; import io.grpc.internal.ServerListener; @@ -133,29 +134,35 @@ class NoHandlerProtocolNegotiator implements ProtocolNegotiator { } NoHandlerProtocolNegotiator protocolNegotiator = new NoHandlerProtocolNegotiator(); - NettyServer ns = new NettyServer( - Arrays.asList(addr), - new ReflectiveChannelFactory<>(NioServerSocketChannel.class), - new HashMap, Object>(), - new HashMap, Object>(), - new FixedObjectPool<>(eventLoop), - new FixedObjectPool<>(eventLoop), - false, - protocolNegotiator, - Collections.emptyList(), - TransportTracer.getDefaultFactory(), - 1, // ignore - false, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore - Attributes.EMPTY, - channelz); + NettyServer ns = + new NettyServer( + Arrays.asList(addr), + new ReflectiveChannelFactory<>(NioServerSocketChannel.class), + new HashMap, Object>(), + new HashMap, Object>(), + new FixedObjectPool<>(eventLoop), + new FixedObjectPool<>(eventLoop), + false, + protocolNegotiator, + Collections.emptyList(), + TransportTracer.getDefaultFactory(), + 1, // ignore + false, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore 
+ 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore + Attributes.EMPTY, + channelz, mock(MetricRecorder.class)); final SettableFuture serverShutdownCalled = SettableFuture.create(); ns.start(new ServerListener() { @Override @@ -184,29 +191,35 @@ public void multiPortStartStopGet() throws Exception { InetSocketAddress addr1 = new InetSocketAddress(0); InetSocketAddress addr2 = new InetSocketAddress(0); - NettyServer ns = new NettyServer( - Arrays.asList(addr1, addr2), - new ReflectiveChannelFactory<>(NioServerSocketChannel.class), - new HashMap, Object>(), - new HashMap, Object>(), - new FixedObjectPool<>(eventLoop), - new FixedObjectPool<>(eventLoop), - false, - ProtocolNegotiators.plaintext(), - Collections.emptyList(), - TransportTracer.getDefaultFactory(), - 1, // ignore - false, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore - Attributes.EMPTY, - channelz); + NettyServer ns = + new NettyServer( + Arrays.asList(addr1, addr2), + new ReflectiveChannelFactory<>(NioServerSocketChannel.class), + new HashMap, Object>(), + new HashMap, Object>(), + new FixedObjectPool<>(eventLoop), + new FixedObjectPool<>(eventLoop), + false, + ProtocolNegotiators.plaintext(), + Collections.emptyList(), + TransportTracer.getDefaultFactory(), + 1, // ignore + false, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore + Attributes.EMPTY, + channelz, mock(MetricRecorder.class)); final SettableFuture shutdownCompleted = SettableFuture.create(); ns.start(new ServerListener() { @Override @@ -258,29 +271,35 @@ public void multiPortConnections() throws Exception { InetSocketAddress addr2 = new InetSocketAddress(0); final CountDownLatch allPortsConnectedCountDown = new CountDownLatch(2); - NettyServer ns = new NettyServer( - Arrays.asList(addr1, addr2), 
- new ReflectiveChannelFactory<>(NioServerSocketChannel.class), - new HashMap, Object>(), - new HashMap, Object>(), - new FixedObjectPool<>(eventLoop), - new FixedObjectPool<>(eventLoop), - false, - ProtocolNegotiators.plaintext(), - Collections.emptyList(), - TransportTracer.getDefaultFactory(), - 1, // ignore - false, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore - Attributes.EMPTY, - channelz); + NettyServer ns = + new NettyServer( + Arrays.asList(addr1, addr2), + new ReflectiveChannelFactory<>(NioServerSocketChannel.class), + new HashMap, Object>(), + new HashMap, Object>(), + new FixedObjectPool<>(eventLoop), + new FixedObjectPool<>(eventLoop), + false, + ProtocolNegotiators.plaintext(), + Collections.emptyList(), + TransportTracer.getDefaultFactory(), + 1, // ignore + false, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore + Attributes.EMPTY, + channelz, mock(MetricRecorder.class)); final SettableFuture shutdownCompleted = SettableFuture.create(); ns.start(new ServerListener() { @Override @@ -320,29 +339,35 @@ public void run() {} public void getPort_notStarted() { InetSocketAddress addr = new InetSocketAddress(0); List addresses = Collections.singletonList(addr); - NettyServer ns = new NettyServer( - addresses, - new ReflectiveChannelFactory<>(NioServerSocketChannel.class), - new HashMap, Object>(), - new HashMap, Object>(), - new FixedObjectPool<>(eventLoop), - new FixedObjectPool<>(eventLoop), - false, - ProtocolNegotiators.plaintext(), - Collections.emptyList(), - TransportTracer.getDefaultFactory(), - 1, // ignore - false, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore - Attributes.EMPTY, - channelz); + NettyServer ns = + new NettyServer( + 
addresses, + new ReflectiveChannelFactory<>(NioServerSocketChannel.class), + new HashMap, Object>(), + new HashMap, Object>(), + new FixedObjectPool<>(eventLoop), + new FixedObjectPool<>(eventLoop), + false, + ProtocolNegotiators.plaintext(), + Collections.emptyList(), + TransportTracer.getDefaultFactory(), + 1, // ignore + false, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore + Attributes.EMPTY, + channelz, mock(MetricRecorder.class)); assertThat(ns.getListenSocketAddress()).isEqualTo(addr); assertThat(ns.getListenSocketAddresses()).isEqualTo(addresses); @@ -395,29 +420,35 @@ class TestProtocolNegotiator implements ProtocolNegotiator { .build(); TestProtocolNegotiator protocolNegotiator = new TestProtocolNegotiator(); InetSocketAddress addr = new InetSocketAddress(0); - NettyServer ns = new NettyServer( - Arrays.asList(addr), - new ReflectiveChannelFactory<>(NioServerSocketChannel.class), - new HashMap, Object>(), - childChannelOptions, - new FixedObjectPool<>(eventLoop), - new FixedObjectPool<>(eventLoop), - false, - protocolNegotiator, - Collections.emptyList(), - TransportTracer.getDefaultFactory(), - 1, // ignore - false, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore - eagAttributes, - channelz); + NettyServer ns = + new NettyServer( + Arrays.asList(addr), + new ReflectiveChannelFactory<>(NioServerSocketChannel.class), + new HashMap, Object>(), + childChannelOptions, + new FixedObjectPool<>(eventLoop), + new FixedObjectPool<>(eventLoop), + false, + protocolNegotiator, + Collections.emptyList(), + TransportTracer.getDefaultFactory(), + 1, // ignore + false, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore + eagAttributes, + 
channelz, mock(MetricRecorder.class)); ns.start(new ServerListener() { @Override public ServerTransportListener transportCreated(ServerTransport transport) { @@ -443,29 +474,35 @@ public void serverShutdown() {} @Test public void channelzListenSocket() throws Exception { InetSocketAddress addr = new InetSocketAddress(0); - NettyServer ns = new NettyServer( - Arrays.asList(addr), - new ReflectiveChannelFactory<>(NioServerSocketChannel.class), - new HashMap, Object>(), - new HashMap, Object>(), - new FixedObjectPool<>(eventLoop), - new FixedObjectPool<>(eventLoop), - false, - ProtocolNegotiators.plaintext(), - Collections.emptyList(), - TransportTracer.getDefaultFactory(), - 1, // ignore - false, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore - Attributes.EMPTY, - channelz); + NettyServer ns = + new NettyServer( + Arrays.asList(addr), + new ReflectiveChannelFactory<>(NioServerSocketChannel.class), + new HashMap, Object>(), + new HashMap, Object>(), + new FixedObjectPool<>(eventLoop), + new FixedObjectPool<>(eventLoop), + false, + ProtocolNegotiators.plaintext(), + Collections.emptyList(), + TransportTracer.getDefaultFactory(), + 1, // ignore + false, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore + Attributes.EMPTY, + channelz, mock(MetricRecorder.class)); final SettableFuture shutdownCompleted = SettableFuture.create(); ns.start(new ServerListener() { @Override @@ -603,12 +640,17 @@ private NettyServer getServer(List addr, EventLoopGroup ev) { 1, // ignore 1, // ignore 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore Attributes.EMPTY, - channelz); + channelz, mock(MetricRecorder.class)); } private static class 
NoopServerTransportListener implements ServerTransportListener { diff --git a/netty/src/test/java/io/grpc/netty/NettyTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyTransportTest.java index b1c89e22f93..22758a8b727 100644 --- a/netty/src/test/java/io/grpc/netty/NettyTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyTransportTest.java @@ -22,10 +22,12 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.ChannelLogger; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.internal.AbstractTransportTest; import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.DisconnectError; import io.grpc.internal.FakeClock; import io.grpc.internal.InternalServer; import io.grpc.internal.ManagedClientTransport; @@ -70,7 +72,7 @@ protected InternalServer newServer( .forAddress(new InetSocketAddress("localhost", 0)) .flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW) .setTransportTracerFactory(fakeClockTransportTracer) - .buildTransportServers(streamTracerFactories); + .buildTransportServers(streamTracerFactories, new MetricRecorder() {}); } @Override @@ -80,7 +82,7 @@ protected InternalServer newServer( .forAddress(new InetSocketAddress("localhost", port)) .flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW) .setTransportTracerFactory(fakeClockTransportTracer) - .buildTransportServers(streamTracerFactories); + .buildTransportServers(streamTracerFactories, new MetricRecorder() {}); } @Override @@ -127,7 +129,7 @@ public void channelHasUnresolvedHostname() throws Exception { .setChannelLogger(logger), logger); Runnable runnable = transport.start(new ManagedClientTransport.Listener() { @Override - public void transportShutdown(Status s) { + public void transportShutdown(Status s, DisconnectError e) { future.set(s); } diff --git a/netty/src/test/java/io/grpc/netty/NettyWritableBufferAllocatorTest.java 
b/netty/src/test/java/io/grpc/netty/NettyWritableBufferAllocatorTest.java index d577ec46b03..0b741ae24b3 100644 --- a/netty/src/test/java/io/grpc/netty/NettyWritableBufferAllocatorTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyWritableBufferAllocatorTest.java @@ -40,13 +40,6 @@ protected WritableBufferAllocator allocator() { return allocator; } - @Test - public void testCapacityHasMinimum() { - WritableBuffer buffer = allocator().allocate(100); - assertEquals(0, buffer.readableBytes()); - assertEquals(4096, buffer.writableBytes()); - } - @Test public void testCapacityIsExactAboveMinimum() { WritableBuffer buffer = allocator().allocate(9000); diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 6939d835892..403b1b64329 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -24,6 +24,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -31,6 +32,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import com.google.common.base.Optional; import io.grpc.Attributes; import io.grpc.CallCredentials; import io.grpc.ChannelCredentials; @@ -44,6 +46,7 @@ import io.grpc.InternalChannelz; import io.grpc.InternalChannelz.Security; import io.grpc.Metadata; +import io.grpc.MetricRecorder; import io.grpc.SecurityLevel; import io.grpc.ServerCredentials; import io.grpc.ServerStreamTracer; @@ -53,6 +56,7 @@ import io.grpc.TlsChannelCredentials; import io.grpc.TlsServerCredentials; import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.DisconnectError; import 
io.grpc.internal.GrpcAttributes; import io.grpc.internal.InternalServer; import io.grpc.internal.ManagedClientTransport; @@ -111,15 +115,20 @@ import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandshakeCompletionEvent; import java.io.File; +import java.io.IOException; import java.io.InputStream; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.ArrayDeque; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -141,7 +150,6 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.DisableOnDebug; -import org.junit.rules.ExpectedException; import org.junit.rules.TestRule; import org.junit.rules.Timeout; import org.junit.runner.RunWith; @@ -169,8 +177,6 @@ public static void loadCerts() throws Exception { private static final int TIMEOUT_SECONDS = 60; @Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(TIMEOUT_SECONDS)); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); private final EventLoopGroup group = new DefaultEventLoop(); private Channel chan; @@ -221,13 +227,52 @@ public ChannelCredentials withoutBearerTokens() { } @Test - public void fromClient_tls() { + public void fromClient_tls_trustManager() + throws KeyStoreException, CertificateException, IOException, NoSuchAlgorithmException { + KeyStore certStore = KeyStore.getInstance(KeyStore.getDefaultType()); + certStore.load(null); + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); 
+ try (InputStream ca = TlsTesting.loadCert("ca.pem")) { + for (X509Certificate cert : CertificateUtils.getX509Certificates(ca)) { + certStore.setCertificateEntry(cert.getSubjectX500Principal().getName("RFC2253"), cert); + } + } + trustManagerFactory.init(certStore); + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(TlsChannelCredentials.newBuilder() + .trustManager(trustManagerFactory.getTrustManagers()).build()); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); + assertThat(((ClientTlsProtocolNegotiator) result.negotiator.newNegotiator()) + .hasX509ExtendedTrustManager()).isTrue(); + } + + @Test + public void fromClient_tls_CaCertsInputStream() throws IOException { + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(TlsChannelCredentials.newBuilder() + .trustManager(TlsTesting.loadCert("ca.pem")).build()); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); + assertThat(((ClientTlsProtocolNegotiator) result.negotiator.newNegotiator()) + .hasX509ExtendedTrustManager()).isTrue(); + } + + @Test + public void fromClient_tls_systemDefault() { ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(TlsChannelCredentials.create()); assertThat(result.error).isNull(); assertThat(result.callCredentials).isNull(); assertThat(result.negotiator) .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); + assertThat(((ClientTlsProtocolNegotiator) result.negotiator.newNegotiator()) + .hasX509ExtendedTrustManager()).isTrue(); } @Test @@ -345,7 +390,9 @@ private Object expectHandshake( .buildTransportFactory(); InternalServer server = NettyServerBuilder .forPort(0, serverCreds) - 
.buildTransportServers(Collections.emptyList()); + .buildTransportServers( + Collections.emptyList(), + new MetricRecorder() {}); server.start(serverListener); ManagedClientTransport.Listener clientTransportListener = @@ -366,7 +413,7 @@ private Object expectHandshake( } else { ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); verify(clientTransportListener, timeout(TIMEOUT_SECONDS * 1000)) - .transportShutdown(captor.capture()); + .transportShutdown(captor.capture(), any(DisconnectError.class)); result = captor.getValue(); } @@ -670,11 +717,10 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { } @Test - public void tlsHandler_failsOnNullEngine() throws Exception { - thrown.expect(NullPointerException.class); - thrown.expectMessage("ssl"); - - Object unused = ProtocolNegotiators.serverTls(null); + public void tlsHandler_failsOnNullEngine() { + NullPointerException e = assertThrows(NullPointerException.class, + () -> ProtocolNegotiators.serverTls(null)); + assertThat(e).hasMessageThat().isEqualTo("sslContext"); } @@ -876,7 +922,8 @@ public String applicationProtocol() { DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", elg, noopLogger); + "authority", elg, noopLogger, Optional.absent(), + getClientTlsProtocolNegotiator(), null); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -914,7 +961,8 @@ public String applicationProtocol() { .applicationProtocolConfig(apn).build(); ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", elg, noopLogger); + "authority", elg, noopLogger, Optional.absent(), + getClientTlsProtocolNegotiator(), null); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -938,7 +986,8 
@@ public String applicationProtocol() { DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", elg, noopLogger); + "authority", elg, noopLogger, Optional.absent(), + getClientTlsProtocolNegotiator(), null); pipeline.addLast(handler); final AtomicReference error = new AtomicReference<>(); @@ -966,7 +1015,8 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { @Test public void clientTlsHandler_closeDuringNegotiation() throws Exception { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", null, noopLogger); + "authority", null, noopLogger, Optional.absent(), + getClientTlsProtocolNegotiator(), null); pipeline.addLast(new WriteBufferingAndExceptionHandler(handler)); ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE); @@ -978,6 +1028,12 @@ public void clientTlsHandler_closeDuringNegotiation() throws Exception { .isEqualTo(Status.Code.UNAVAILABLE); } + private ClientTlsProtocolNegotiator getClientTlsProtocolNegotiator() throws SSLException { + return new ClientTlsProtocolNegotiator(GrpcSslContexts.forClient().trustManager( + TlsTesting.loadCert("ca.pem")).build(), + null, Optional.absent(), null, ""); + } + @Test public void engineLog() { ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); @@ -1004,9 +1060,8 @@ public boolean isLoggable(LogRecord record) { @Test public void tls_failsOnNullSslContext() { - thrown.expect(NullPointerException.class); - - Object unused = ProtocolNegotiators.tls(null); + assertThrows(NullPointerException.class, + () -> ProtocolNegotiators.tls(null, null)); } @Test @@ -1036,23 +1091,23 @@ public void tls_invalidHost() throws SSLException { } @Test - public void httpProxy_nullAddressNpe() throws Exception { - thrown.expect(NullPointerException.class); - Object unused = - ProtocolNegotiators.httpProxy(null, "user", "pass", 
ProtocolNegotiators.plaintext()); + public void httpProxy_nullAddressNpe() { + assertThrows(NullPointerException.class, + () -> ProtocolNegotiators.httpProxy(null, null, "user", "pass", + ProtocolNegotiators.plaintext())); } @Test - public void httpProxy_nullNegotiatorNpe() throws Exception { - thrown.expect(NullPointerException.class); - Object unused = ProtocolNegotiators.httpProxy( - InetSocketAddress.createUnresolved("localhost", 80), "user", "pass", null); + public void httpProxy_nullNegotiatorNpe() { + assertThrows(NullPointerException.class, + () -> ProtocolNegotiators.httpProxy( + InetSocketAddress.createUnresolved("localhost", 80), null, "user", "pass", null)); } @Test public void httpProxy_nullUserPassNoException() throws Exception { assertNotNull(ProtocolNegotiators.httpProxy( - InetSocketAddress.createUnresolved("localhost", 80), null, null, + InetSocketAddress.createUnresolved("localhost", 80), null, null, null, ProtocolNegotiators.plaintext())); } @@ -1070,7 +1125,7 @@ public void httpProxy_completes() throws Exception { .bind(proxy).sync().channel(); ProtocolNegotiator nego = - ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext()); + ProtocolNegotiators.httpProxy(proxy, null, null, null, ProtocolNegotiators.plaintext()); // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, // mocking the behavior using KickStartHandler. ChannelHandler handler = @@ -1133,7 +1188,7 @@ public void httpProxy_500() throws Exception { .bind(proxy).sync().channel(); ProtocolNegotiator nego = - ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext()); + ProtocolNegotiators.httpProxy(proxy, null, null, null, ProtocolNegotiators.plaintext()); // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, // mocking the behavior using KickStartHandler. 
ChannelHandler handler = @@ -1164,14 +1219,84 @@ public void httpProxy_500() throws Exception { assertFalse(negotiationFuture.isDone()); String response = "HTTP/1.1 500 OMG\r\nContent-Length: 4\r\n\r\noops"; serverContext.writeAndFlush(bb(response, serverContext.channel())).sync(); - thrown.expect(ProxyConnectException.class); try { - negotiationFuture.sync(); + assertThrows(ProxyConnectException.class, negotiationFuture::sync); } finally { channel.close(); } } + @Test + public void httpProxy_customHeaders() throws Exception { + DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); + // ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called + // the channel is already active. + LocalAddress proxy = new LocalAddress("httpProxy_customHeaders"); + SocketAddress host = InetSocketAddress.createUnresolved("example.com", 443); + + ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class); + Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class) + .childHandler(mockHandler) + .bind(proxy).sync().channel(); + + Map headers = new java.util.HashMap<>(); + headers.put("X-Custom-Header", "custom-value"); + headers.put("Proxy-Authorization", "Bearer token123"); + + ProtocolNegotiator nego = ProtocolNegotiators.httpProxy( + proxy, headers, null, null, ProtocolNegotiators.plaintext()); + // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, + // mocking the behavior using KickStartHandler. 
+ ChannelHandler handler = + new KickStartHandler(nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler())); + Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler) + .register().sync().channel(); + pipeline = channel.pipeline(); + // Wait for initialization to complete + channel.eventLoop().submit(NOOP_RUNNABLE).sync(); + channel.connect(host).sync(); + serverChannel.close(); + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(ChannelHandlerContext.class); + Mockito.verify(mockHandler).channelActive(contextCaptor.capture()); + ChannelHandlerContext serverContext = contextCaptor.getValue(); + + final String golden = "testData"; + ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel)); + + // Wait for sending initial request to complete + channel.eventLoop().submit(NOOP_RUNNABLE).sync(); + ArgumentCaptor objectCaptor = ArgumentCaptor.forClass(Object.class); + Mockito.verify(mockHandler) + .channelRead(ArgumentMatchers.any(), objectCaptor.capture()); + ByteBuf b = (ByteBuf) objectCaptor.getValue(); + String request = b.toString(UTF_8); + b.release(); + + // Verify custom headers are present in the CONNECT request + assertTrue("No trailing newline: " + request, request.endsWith("\r\n\r\n")); + assertTrue("No CONNECT: " + request, request.startsWith("CONNECT example.com:443 ")); + assertTrue("No custom header: " + request, + request.contains("X-Custom-Header: custom-value")); + assertTrue("No proxy authorization: " + request, + request.contains("Proxy-Authorization: Bearer token123")); + + assertFalse(negotiationFuture.isDone()); + serverContext.writeAndFlush(bb("HTTP/1.1 200 OK\r\n\r\n", serverContext.channel())).sync(); + negotiationFuture.sync(); + + channel.eventLoop().submit(NOOP_RUNNABLE).sync(); + objectCaptor = ArgumentCaptor.forClass(Object.class); + Mockito.verify(mockHandler, times(2)) + .channelRead(ArgumentMatchers.any(), objectCaptor.capture()); + b = (ByteBuf) 
objectCaptor.getAllValues().get(1); + String preface = b.toString(UTF_8); + b.release(); + assertEquals(golden, preface); + + channel.close(); + } + @Test public void waitUntilActiveHandler_firesNegotiation() throws Exception { EventLoopGroup elg = new DefaultEventLoopGroup(1); @@ -1228,7 +1353,8 @@ public void clientTlsHandler_firesNegotiation() throws Exception { serverSslContext = GrpcSslContexts.forServer(server1Chain, server1Key).build(); } FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler(); - ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, null); + ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, + null, Optional.absent(), null, null); WriteBufferingAndExceptionHandler clientWbaeh = new WriteBufferingAndExceptionHandler(pn.newHandler(gh)); diff --git a/netty/src/test/java/io/grpc/netty/TcpMetricsTest.java b/netty/src/test/java/io/grpc/netty/TcpMetricsTest.java new file mode 100644 index 00000000000..f75a98b46df --- /dev/null +++ b/netty/src/test/java/io/grpc/netty/TcpMetricsTest.java @@ -0,0 +1,616 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.netty; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import io.grpc.InternalTcpMetrics; +import io.grpc.MetricRecorder; +import io.netty.util.concurrent.ScheduledFuture; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class TcpMetricsTest { + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private MetricRecorder metricRecorder; + + private ConfigurableFakeWithTcpInfo channel; + private TcpMetrics metrics; + + @Before + public void setUp() throws Exception { + FakeEpollTcpInfo dummyInfo = new FakeEpollTcpInfo(); + channel = new ConfigurableFakeWithTcpInfo(dummyInfo); + metrics = new TcpMetrics(metricRecorder); + } + + @After + public void tearDown() throws Exception { + TcpMetrics.epollInfo = TcpMetrics.loadEpollInfo(); + } + + @Test + public void metricsInitialization() { + + assertNotNull(InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT); + assertNotNull(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT); + assertNotNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT); + 
assertNotNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT); + assertNotNull(InternalTcpMetrics.MIN_RTT_INSTRUMENT); + } + + public static class FakeEpollTcpInfo { + long totalRetrans; + long retransmits; + long rtt; + + public void setValues(long totalRetrans, long retransmits, long rtt) { + this.totalRetrans = totalRetrans; + this.retransmits = retransmits; + this.rtt = rtt; + } + + @SuppressWarnings("unused") + public long totalRetrans() { + return totalRetrans; + } + + @SuppressWarnings("unused") + public long retrans() { + return retransmits; + } + + @SuppressWarnings("unused") + public long rtt() { + return rtt; + } + } + + @Test + public void tracker_recordTcpInfo_reflectionSuccess() throws Exception { + MetricRecorder recorder = mock(MetricRecorder.class); + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + TcpMetrics tracker = new TcpMetrics(recorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + infoSource.setValues(123, 4, 5000); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + channel.writeInbound("dummy"); + + tracker.channelInactive(channel); + + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(123L), any(), any()); + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT)), + eq(4L), any(), any()); + verify(recorder).recordDoubleHistogram( + eq(Objects.requireNonNull(InternalTcpMetrics.MIN_RTT_INSTRUMENT)), + eq(0.005), any(), any()); + } + + @Test + public void tracker_periodicRecord_doesNotRecordRecurringRetransmits() throws Exception { + MetricRecorder recorder 
= mock(MetricRecorder.class); + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + TcpMetrics tracker = new TcpMetrics(recorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + infoSource.setValues(123, 4, 5000); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + tracker.channelActive(channel); + + ScheduledFuture timer = tracker.getReportTimer(); + assertNotNull("Timer should be scheduled", timer); + + long delay = timer.getDelay(TimeUnit.MILLISECONDS); + channel.advanceTimeBy(delay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(123L), any(), any()); + verify(recorder).recordDoubleHistogram( + eq(Objects.requireNonNull(InternalTcpMetrics.MIN_RTT_INSTRUMENT)), + eq(0.005), any(), any()); + // Should NOT record recurring retransmits during periodic polling + verify(recorder, org.mockito.Mockito.never()) + .addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT)), + anyLong(), any(), any()); + } + + @Test + public void tracker_channelInactive_recordsRecurringRetransmits_raw_notDelta() throws Exception { + MetricRecorder recorder = mock(MetricRecorder.class); + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + TcpMetrics tracker = new 
TcpMetrics(recorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + infoSource.setValues(123, 4, 5000); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + tracker.channelActive(channel); + + ScheduledFuture timer = tracker.getReportTimer(); + assertNotNull("Timer should be scheduled", timer); + + long delay = timer.getDelay(TimeUnit.MILLISECONDS); + channel.advanceTimeBy(delay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + org.mockito.Mockito.clearInvocations(recorder); + + // Let's just create a new channel instance where tcpInfo sets retrans=5. + FakeEpollTcpInfo infoSource2 = new FakeEpollTcpInfo(); + infoSource2.setValues(130, 5, 5000); + ConfigurableFakeWithTcpInfo channel2 = new ConfigurableFakeWithTcpInfo(infoSource2); + + tracker.channelInactive(channel2); + + // It should record delta for totalRetrans (130 - 123 = 7) + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(7L), any(), any()); + // But for recurringRetransmits it MUST record the raw value 5, not the delta! 
+ verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT)), + eq(5L), any(), any()); + } + + @Test + public void tracker_periodicRecord_reportsDeltaForTotalRetrans() throws Exception { + MetricRecorder recorder = mock(MetricRecorder.class); + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + TcpMetrics tracker = new TcpMetrics(recorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + infoSource.setValues(123, 4, 5000); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + tracker.channelActive(channel); + + ScheduledFuture timer = tracker.getReportTimer(); + assertNotNull("Timer should be scheduled", timer); + + long delay = timer.getDelay(TimeUnit.MILLISECONDS); + channel.advanceTimeBy(delay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(123L), any(), any()); + + org.mockito.Mockito.clearInvocations(recorder); + + // Change tcpInfo for second periodic record + infoSource.setValues(150, 2, 6000); // 150 - 123 = 27 + + ScheduledFuture newTimer = tracker.getReportTimer(); + assertNotNull("New timer should be scheduled", newTimer); + assertNotSame("Timer should be a new instance", timer, newTimer); + long newDelay = newTimer.getDelay(TimeUnit.MILLISECONDS); + channel.advanceTimeBy(newDelay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + // Only the delta (150 - 123 = 27) should be recorded + verify(recorder).addLongCounter( + 
eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(27L), any(), any()); + verify(recorder).recordDoubleHistogram( + eq(Objects.requireNonNull(InternalTcpMetrics.MIN_RTT_INSTRUMENT)), + eq(0.006), any(), any()); + verify(recorder, org.mockito.Mockito.never()) + .addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT)), + anyLong(), any(), any()); + } + + @Test + public void tracker_periodicRecord_doesNotReportZeroDeltaForTotalRetrans() throws Exception { + MetricRecorder recorder = mock(MetricRecorder.class); + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + TcpMetrics tracker = new TcpMetrics(recorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + infoSource.setValues(123, 4, 5000); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + tracker.channelActive(channel); + + ScheduledFuture timer = tracker.getReportTimer(); + assertNotNull("Timer should be scheduled", timer); + + long delay = timer.getDelay(TimeUnit.MILLISECONDS); + channel.advanceTimeBy(delay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(123L), any(), any()); + + org.mockito.Mockito.clearInvocations(recorder); + + // Keep tcpInfo the same for second periodic record + ScheduledFuture newTimer = tracker.getReportTimer(); + assertNotNull("New timer should be scheduled", newTimer); + assertNotSame("Timer should be a new instance", timer, newTimer); + long newDelay = newTimer.getDelay(TimeUnit.MILLISECONDS); + 
channel.advanceTimeBy(newDelay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + // NO delta (123 - 123 = 0), so it should not be recorded + verify(recorder, org.mockito.Mockito.never()) + .addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + anyLong(), any(), any()); + + // MIN_RTT should be recorded again! + verify(recorder).recordDoubleHistogram( + eq(Objects.requireNonNull(InternalTcpMetrics.MIN_RTT_INSTRUMENT)), + eq(0.005), any(), any()); + } + + public static class ConfigurableFakeWithTcpInfo extends + io.netty.channel.embedded.EmbeddedChannel { + private final FakeEpollTcpInfo infoToCopy; + + public ConfigurableFakeWithTcpInfo(FakeEpollTcpInfo infoToCopy) { + this.infoToCopy = infoToCopy; + } + + public void tcpInfo(FakeEpollTcpInfo info) { + info.totalRetrans = infoToCopy.totalRetrans; + info.retransmits = infoToCopy.retransmits; + info.rtt = infoToCopy.rtt; + } + } + + private static class AddressOverrideEmbeddedChannel extends + io.netty.channel.embedded.EmbeddedChannel { + private final SocketAddress local; + private final SocketAddress remote; + + public AddressOverrideEmbeddedChannel(SocketAddress local, SocketAddress remote) { + this.local = local; + this.remote = remote; + } + + @Override + public SocketAddress localAddress() { + return local; + } + + @Override + public SocketAddress remoteAddress() { + return remote; + } + } + + @Test + public void tracker_reportsDeltas_correctly() throws Exception { + MetricRecorder recorder = mock(MetricRecorder.class); + + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + TcpMetrics tracker = new TcpMetrics(recorder); + + FakeEpollTcpInfo 
infoSource = new FakeEpollTcpInfo(); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + // 10 retransmits total + infoSource.setValues(10, 2, 1000); + tracker.recordTcpInfo(channel); + + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(10L), any(), any()); + + // 15 retransmits total (delta 5) + infoSource.setValues(15, 0, 1000); + tracker.recordTcpInfo(channel); + + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(5L), any(), any()); + + // 15 retransmits total (delta 0) - should NOT report + // also set retransmits to 1 + infoSource.setValues(15, 1, 1000); + tracker.recordTcpInfo(channel); + // Verify no new interactions with this specific metric and value + // We can't easily verify "no interaction" for specific value without capturing. + verify(recorder, org.mockito.Mockito.times(1)).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(10L), any(), any()); + verify(recorder, org.mockito.Mockito.times(1)).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(5L), any(), any()); + // Total interactions for packetsRetransmitted should be 2 + verify(recorder, org.mockito.Mockito.times(2)).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + anyLong(), any(), any()); + + // recurringRetransmits should NOT have been reported yet (periodic calls) + verify(recorder, org.mockito.Mockito.times(0)).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT)), + anyLong(), any(), any()); + + // Close channel - should report recurringRetransmits + tracker.channelInactive(channel); + verify(recorder, org.mockito.Mockito.times(1)).addLongCounter( + 
eq(Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT)), + eq(1L), // From last infoSource setValues(15, 1, 1000) + any(), any()); + } + + @Test + public void tracker_recordTcpInfo_reflectionFailure() { + MetricRecorder recorder = mock(MetricRecorder.class); + + TcpMetrics.epollInfo = null; + TcpMetrics tracker = new TcpMetrics(recorder); + + io.netty.channel.embedded.EmbeddedChannel channel = new + io.netty.channel.embedded.EmbeddedChannel(); + + // Should catch exception and ignore + tracker.channelInactive(channel); + } + + @Test + public void registeredMetrics_haveCorrectOptionalLabels() { + List expectedOptionalLabels = Arrays.asList( + "network.local.address", + "network.local.port", + "network.peer.address", + "network.peer.port"); + + assertEquals( + expectedOptionalLabels, + InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT.getOptionalLabelKeys()); + assertEquals( + expectedOptionalLabels, + InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT.getOptionalLabelKeys()); + + assertEquals( + expectedOptionalLabels, + Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT) + .getOptionalLabelKeys()); + assertEquals( + expectedOptionalLabels, + Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT) + .getOptionalLabelKeys()); + assertEquals( + expectedOptionalLabels, + Objects.requireNonNull(InternalTcpMetrics.MIN_RTT_INSTRUMENT).getOptionalLabelKeys()); + } + + @Test + public void channelActive_extractsLabels_ipv4() throws Exception { + InetAddress localInet = InetAddress.getByAddress(new byte[] {127, 0, 0, 1}); + InetAddress remoteInet = InetAddress.getByAddress(new byte[] {127, 0, 0, 2}); + + AddressOverrideEmbeddedChannel channel = new AddressOverrideEmbeddedChannel( + new InetSocketAddress(localInet, 8080), + new InetSocketAddress(remoteInet, 443)); + + metrics.channelActive(channel); + + verify(metricRecorder).addLongCounter( + eq(InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT), eq(1L), + 
eq(Collections.emptyList()), + eq(Arrays.asList("127.0.0.1", "8080", "127.0.0.2", "443"))); + verify(metricRecorder).addLongUpDownCounter( + eq(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT), eq(1L), + eq(Collections.emptyList()), + eq(Arrays.asList("127.0.0.1", "8080", "127.0.0.2", "443"))); + verifyNoMoreInteractions(metricRecorder); + } + + @Test + public void channelInactive_extractsLabels_ipv6() throws Exception { + InetAddress localInet = InetAddress.getByAddress(new byte[] {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1}); + InetAddress remoteInet = InetAddress.getByAddress(new byte[] {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2}); + + AddressOverrideEmbeddedChannel channel = new AddressOverrideEmbeddedChannel( + new InetSocketAddress(localInet, 8080), + new InetSocketAddress(remoteInet, 443)); + + metrics.channelInactive(channel); + + verify(metricRecorder).addLongUpDownCounter( + eq(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT), eq(-1L), + eq(Collections.emptyList()), + eq(Arrays.asList("0:0:0:0:0:0:0:1", "8080", "0:0:0:0:0:0:0:2", "443"))); + verifyNoMoreInteractions(metricRecorder); + } + + @Test + public void channelActive_extractsLabels_nonInetAddress() { + SocketAddress dummyAddress = new SocketAddress() { + }; + AddressOverrideEmbeddedChannel channel = new AddressOverrideEmbeddedChannel( + dummyAddress, dummyAddress); + + metrics.channelActive(channel); + + verify(metricRecorder).addLongCounter( + eq(InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT), eq(1L), + eq(Collections.emptyList()), + eq(Arrays.asList("", "", "", ""))); + verify(metricRecorder).addLongUpDownCounter( + eq(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT), eq(1L), + eq(Collections.emptyList()), + eq(Arrays.asList("", "", "", ""))); + verifyNoMoreInteractions(metricRecorder); + } + + @Test + public void channelActive_incrementsCounts() { + metrics.channelActive(channel); + verify(metricRecorder).addLongCounter( + eq(InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT), eq(1L), + eq(Collections.emptyList()), 
+ eq(Arrays.asList("", "", "", ""))); + verify(metricRecorder).addLongUpDownCounter( + eq(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT), eq(1L), + eq(Collections.emptyList()), + eq(Arrays.asList("", "", "", ""))); + verifyNoMoreInteractions(metricRecorder); + } + + @Test + public void channelInactive_decrementsCount_noEpoll_noError() { + metrics.channelInactive(channel); + verify(metricRecorder).addLongUpDownCounter( + eq(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT), eq(-1L), + eq(Collections.emptyList()), + eq(Arrays.asList("", "", "", ""))); + verifyNoMoreInteractions(metricRecorder); + } + + @Test + public void channelActive_schedulesReportTimer() throws Exception { + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + + metrics = new TcpMetrics(metricRecorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + metrics.channelActive(channel); + + ScheduledFuture timer = metrics.getReportTimer(); + assertNotNull("Timer should be scheduled", timer); + + long delay = timer.getDelay(TimeUnit.MILLISECONDS); + assertTrue("Delay should be >= 30000 but was " + delay, delay >= 30_000); + assertTrue("Delay should be <= 330000 but was " + delay, delay <= 330_000); + + // Advance time to trigger the task + channel.advanceTimeBy(delay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + // Verify rescheduling + ScheduledFuture newTimer = metrics.getReportTimer(); + assertNotNull("New timer should be scheduled", newTimer); + assertNotSame("Timer should be a new instance", timer, newTimer); + + long newDelay = newTimer.getDelay(TimeUnit.MILLISECONDS); + // 
Re-arming jitter is 90% to 110%, so 270,000 ms to 330,000 ms + assertTrue("Delay should be >= 270000 but was " + newDelay, newDelay >= 270_000); + assertTrue("Delay should be <= 330000 but was " + newDelay, newDelay <= 330_000); + } + + @Test + public void channelInactive_cancelsReportTimer() throws Exception { + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + + metrics = new TcpMetrics(metricRecorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + metrics.channelActive(channel); + + ScheduledFuture timer = metrics.getReportTimer(); + assertNotNull("Timer should be scheduled", timer); + + metrics.channelInactive(channel); + + assertTrue("Timer should be cancelled", timer.isCancelled()); + } +} diff --git a/netty/src/test/java/io/grpc/netty/UdsNameResolverProviderTest.java b/netty/src/test/java/io/grpc/netty/UdsNameResolverProviderTest.java index 6a329c8fc68..1766a8e4134 100644 --- a/netty/src/test/java/io/grpc/netty/UdsNameResolverProviderTest.java +++ b/netty/src/test/java/io/grpc/netty/UdsNameResolverProviderTest.java @@ -17,19 +17,30 @@ package io.grpc.netty; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.TruthJUnit.assume; import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import io.grpc.ChannelLogger; import io.grpc.EquivalentAddressGroup; import io.grpc.NameResolver; +import io.grpc.NameResolver.ServiceConfigParser; +import io.grpc.SynchronizationContext; +import io.grpc.Uri; +import io.grpc.internal.FakeClock; +import 
io.grpc.internal.GrpcUtil; import io.netty.channel.unix.DomainSocketAddress; import java.net.SocketAddress; import java.net.URI; +import java.util.Arrays; import java.util.List; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; @@ -37,8 +48,16 @@ import org.mockito.junit.MockitoRule; /** Unit tests for {@link UdsNameResolverProvider}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class UdsNameResolverProviderTest { + private static final int DEFAULT_PORT = 887; + + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @@ -51,56 +70,81 @@ public class UdsNameResolverProviderTest { UdsNameResolverProvider udsNameResolverProvider = new UdsNameResolverProvider(); + private final SynchronizationContext syncContext = new SynchronizationContext( + (t, e) -> { + throw new AssertionError(e); + }); + private final FakeClock fakeExecutor = new FakeClock(); + private final NameResolver.Args args = NameResolver.Args.newBuilder() + .setDefaultPort(DEFAULT_PORT) + .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR) + .setSynchronizationContext(syncContext) + .setServiceConfigParser(mock(ServiceConfigParser.class)) + .setChannelLogger(mock(ChannelLogger.class)) + .setScheduledExecutorService(fakeExecutor.getScheduledExecutorService()) + .build(); @Test public void testUnixRelativePath() { - UdsNameResolver udsNameResolver = - udsNameResolverProvider.newNameResolver(URI.create("unix:sock.sock"), null); - assertThat(udsNameResolver).isNotNull(); - 
udsNameResolver.start(mockListener); - verify(mockListener).onResult(resultCaptor.capture()); - NameResolver.ResolutionResult result = resultCaptor.getValue(); - List list = result.getAddresses(); - assertThat(list).isNotNull(); - assertThat(list).hasSize(1); - EquivalentAddressGroup eag = list.get(0); - assertThat(eag).isNotNull(); - List addresses = eag.getAddresses(); - assertThat(addresses).hasSize(1); - assertThat(addresses.get(0)).isInstanceOf(DomainSocketAddress.class); - DomainSocketAddress domainSocketAddress = (DomainSocketAddress) addresses.get(0); + UdsNameResolver udsNameResolver = newNameResolver("unix:sock.sock", args); + DomainSocketAddress domainSocketAddress = startAndGetUniqueResolvedAddress(udsNameResolver); assertThat(domainSocketAddress.path()).isEqualTo("sock.sock"); } @Test public void testUnixAbsolutePath() { - UdsNameResolver udsNameResolver = - udsNameResolverProvider.newNameResolver(URI.create("unix:/sock.sock"), null); - assertThat(udsNameResolver).isNotNull(); - udsNameResolver.start(mockListener); - verify(mockListener).onResult(resultCaptor.capture()); - NameResolver.ResolutionResult result = resultCaptor.getValue(); - List list = result.getAddresses(); - assertThat(list).isNotNull(); - assertThat(list).hasSize(1); - EquivalentAddressGroup eag = list.get(0); - assertThat(eag).isNotNull(); - List addresses = eag.getAddresses(); - assertThat(addresses).hasSize(1); - assertThat(addresses.get(0)).isInstanceOf(DomainSocketAddress.class); - DomainSocketAddress domainSocketAddress = (DomainSocketAddress) addresses.get(0); + UdsNameResolver udsNameResolver = newNameResolver("unix:/sock.sock", args); + DomainSocketAddress domainSocketAddress = startAndGetUniqueResolvedAddress(udsNameResolver); assertThat(domainSocketAddress.path()).isEqualTo("/sock.sock"); } @Test public void testUnixAbsoluteAlternatePath() { - UdsNameResolver udsNameResolver = - udsNameResolverProvider.newNameResolver(URI.create("unix:///sock.sock"), null); + UdsNameResolver 
udsNameResolver = newNameResolver("unix:///sock.sock", args); + DomainSocketAddress domainSocketAddress = startAndGetUniqueResolvedAddress(udsNameResolver); + assertThat(domainSocketAddress.path()).isEqualTo("/sock.sock"); + } + + @Test + public void testUnixPathWithAuthority() { + try { + newNameResolver("unix://localhost/sock.sock", args); + fail("exception expected"); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().isEqualTo("authority not supported: localhost"); + } + } + + @Test + public void testUnixAbsolutePathDoesNotIncludeQueryOrFragment() { + UdsNameResolver udsNameResolver = newNameResolver("unix:///sock.sock?query#fragment", args); + DomainSocketAddress domainSocketAddress = startAndGetUniqueResolvedAddress(udsNameResolver); + assertThat(domainSocketAddress.path()).isEqualTo("/sock.sock"); + } + + @Test + public void testUnixRelativePathDoesNotIncludeQueryOrFragment() { + // This test fails without RFC 3986 support because of a bug in the legacy java.net.URI-based + // NRP implementation. + assume().that(enableRfc3986UrisParam).isTrue(); + + UdsNameResolver udsNameResolver = newNameResolver("unix:sock.sock?query#fragment", args); + DomainSocketAddress domainSocketAddress = startAndGetUniqueResolvedAddress(udsNameResolver); + assertThat(domainSocketAddress.path()).isEqualTo("sock.sock"); + } + + private UdsNameResolver newNameResolver(String uriString, NameResolver.Args args) { + return enableRfc3986UrisParam + ? 
(UdsNameResolver) udsNameResolverProvider.newNameResolver(Uri.create(uriString), args) + : udsNameResolverProvider.newNameResolver(URI.create(uriString), args); + } + + private DomainSocketAddress startAndGetUniqueResolvedAddress(UdsNameResolver udsNameResolver) { assertThat(udsNameResolver).isNotNull(); udsNameResolver.start(mockListener); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); NameResolver.ResolutionResult result = resultCaptor.getValue(); - List list = result.getAddresses(); + List list = result.getAddressesOrError().getValue(); assertThat(list).isNotNull(); assertThat(list).hasSize(1); EquivalentAddressGroup eag = list.get(0); @@ -108,17 +152,6 @@ public void testUnixAbsoluteAlternatePath() { List addresses = eag.getAddresses(); assertThat(addresses).hasSize(1); assertThat(addresses.get(0)).isInstanceOf(DomainSocketAddress.class); - DomainSocketAddress domainSocketAddress = (DomainSocketAddress) addresses.get(0); - assertThat(domainSocketAddress.path()).isEqualTo("/sock.sock"); - } - - @Test - public void testUnixPathWithAuthority() { - try { - udsNameResolverProvider.newNameResolver(URI.create("unix://localhost/sock.sock"), null); - fail("exception expected"); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().isEqualTo("non-null authority not supported"); - } + return (DomainSocketAddress) addresses.get(0); } } diff --git a/netty/src/test/java/io/grpc/netty/UdsNameResolverTest.java b/netty/src/test/java/io/grpc/netty/UdsNameResolverTest.java index 8eb010e23e5..7bf808c18ce 100644 --- a/netty/src/test/java/io/grpc/netty/UdsNameResolverTest.java +++ b/netty/src/test/java/io/grpc/netty/UdsNameResolverTest.java @@ -18,10 +18,16 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import io.grpc.ChannelLogger; import 
io.grpc.EquivalentAddressGroup; import io.grpc.NameResolver; +import io.grpc.NameResolver.ServiceConfigParser; +import io.grpc.SynchronizationContext; +import io.grpc.internal.FakeClock; +import io.grpc.internal.GrpcUtil; import io.netty.channel.unix.DomainSocketAddress; import java.net.SocketAddress; import java.util.List; @@ -41,7 +47,20 @@ public class UdsNameResolverTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - + private static final int DEFAULT_PORT = 887; + private final FakeClock fakeExecutor = new FakeClock(); + private final SynchronizationContext syncContext = new SynchronizationContext( + (t, e) -> { + throw new AssertionError(e); + }); + private final NameResolver.Args args = NameResolver.Args.newBuilder() + .setDefaultPort(DEFAULT_PORT) + .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR) + .setSynchronizationContext(syncContext) + .setServiceConfigParser(mock(ServiceConfigParser.class)) + .setChannelLogger(mock(ChannelLogger.class)) + .setScheduledExecutorService(fakeExecutor.getScheduledExecutorService()) + .build(); @Mock private NameResolver.Listener2 mockListener; @@ -52,11 +71,11 @@ public class UdsNameResolverTest { @Test public void testValidTargetPath() { - udsNameResolver = new UdsNameResolver(null, "sock.sock"); + udsNameResolver = new UdsNameResolver(null, "sock.sock", args); udsNameResolver.start(mockListener); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); NameResolver.ResolutionResult result = resultCaptor.getValue(); - List list = result.getAddresses(); + List list = result.getAddressesOrError().getValue(); assertThat(list).isNotNull(); assertThat(list).hasSize(1); EquivalentAddressGroup eag = list.get(0); @@ -72,10 +91,10 @@ public void testValidTargetPath() { @Test public void testNonNullAuthority() { try { - udsNameResolver = new UdsNameResolver("authority", "sock.sock"); + udsNameResolver = new UdsNameResolver("somehost", "sock.sock", args); 
fail("exception expected"); } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().isEqualTo("non-null authority not supported"); + assertThat(e).hasMessageThat().isEqualTo("authority not supported: somehost"); } } } diff --git a/okhttp/BUILD.bazel b/okhttp/BUILD.bazel index 80068c9bb5b..74a9f7a4300 100644 --- a/okhttp/BUILD.bazel +++ b/okhttp/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( diff --git a/okhttp/build.gradle b/okhttp/build.gradle index 063e4775de1..6c542feec9c 100644 --- a/okhttp/build.gradle +++ b/okhttp/build.gradle @@ -31,8 +31,16 @@ dependencies { project(':grpc-testing-proto'), libraries.netty.codec.http2, libraries.okhttp - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } project.sourceSets { diff --git a/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java index 1ac64d7ebb5..01ee23b905c 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java +++ b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.internal.SerializingExecutor; import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; import io.grpc.okhttp.internal.framed.ErrorCode; @@ -30,7 +31,6 @@ import java.io.IOException; import java.net.Socket; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import okio.Buffer; import okio.Sink; import okio.Timeout; diff --git a/okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java 
b/okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java new file mode 100644 index 00000000000..6e6a6f12a39 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java @@ -0,0 +1,117 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import java.io.IOException; +import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; + +/** A no-op ssl socket, to facilitate overriding only the required methods in specific + * implementations. 
+ */ +class NoopSslSocket extends SSLSocket { + @Override + public String[] getSupportedCipherSuites() { + return new String[0]; + } + + @Override + public String[] getEnabledCipherSuites() { + return new String[0]; + } + + @Override + public void setEnabledCipherSuites(String[] suites) { + + } + + @Override + public String[] getSupportedProtocols() { + return new String[0]; + } + + @Override + public String[] getEnabledProtocols() { + return new String[0]; + } + + @Override + public void setEnabledProtocols(String[] protocols) { + + } + + @Override + public SSLSession getSession() { + return null; + } + + @Override + public void addHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void startHandshake() throws IOException { + + } + + @Override + public void setUseClientMode(boolean mode) { + + } + + @Override + public boolean getUseClientMode() { + return false; + } + + @Override + public void setNeedClientAuth(boolean need) { + + } + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean want) { + + } + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean flag) { + + } + + @Override + public boolean getEnableSessionCreation() { + return false; + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index 15508110344..98f764132fe 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -17,11 +17,13 @@ package io.grpc.okhttp; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.internal.CertificateUtils.createTrustManager; import static 
io.grpc.internal.GrpcUtil.DEFAULT_KEEPALIVE_TIMEOUT_NANOS; import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.CallCredentials; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; @@ -72,7 +74,6 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; @@ -81,8 +82,6 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; -import javax.security.auth.x500.X500Principal; /** Convenience class for building channels with the OkHttp transport. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785") @@ -91,6 +90,7 @@ public final class OkHttpChannelBuilder extends ForwardingChannelBuilder2> getSupportedSocketAddressTypes() { return Collections.singleton(InetSocketAddress.class); } @@ -799,6 +782,7 @@ static final class OkHttpTransportFactory implements ClientTransportFactory { private final boolean keepAliveWithoutCalls; final int maxInboundMetadataSize; final boolean useGetForSafeMethods; + private final ChannelCredentials channelCredentials; private boolean closed; private OkHttpTransportFactory( @@ -816,7 +800,8 @@ private OkHttpTransportFactory( boolean keepAliveWithoutCalls, int maxInboundMetadataSize, TransportTracer.Factory transportTracerFactory, - boolean useGetForSafeMethods) { + boolean useGetForSafeMethods, + ChannelCredentials channelCredentials) { this.executorPool = executorPool; this.executor = executorPool.getObject(); this.scheduledExecutorServicePool = scheduledExecutorServicePool; @@ -834,6 +819,7 @@ private OkHttpTransportFactory( 
this.keepAliveWithoutCalls = keepAliveWithoutCalls; this.maxInboundMetadataSize = maxInboundMetadataSize; this.useGetForSafeMethods = useGetForSafeMethods; + this.channelCredentials = channelCredentials; this.transportTracerFactory = Preconditions.checkNotNull(transportTracerFactory, "transportTracerFactory"); @@ -861,7 +847,8 @@ public void run() { options.getUserAgent(), options.getEagAttributes(), options.getHttpConnectProxiedSocketAddress(), - tooManyPingsRunnable); + tooManyPingsRunnable, + channelCredentials); if (enableKeepAlive) { transport.enableKeepAlive( true, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos, keepAliveWithoutCalls); @@ -897,7 +884,8 @@ public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials ch keepAliveWithoutCalls, maxInboundMetadataSize, transportTracerFactory, - useGetForSafeMethods); + useGetForSafeMethods, + channelCredentials); return new SwapChannelCredentialsResult(factory, result.callCredentials); } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java index 9d9fe160715..8dd55d9f23e 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java @@ -21,6 +21,7 @@ import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import com.google.common.io.BaseEncoding; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Metadata; @@ -37,7 +38,6 @@ import io.perfmark.Tag; import io.perfmark.TaskCloseable; import java.util.List; -import javax.annotation.concurrent.GuardedBy; import okio.Buffer; /** @@ -409,7 +409,7 @@ private void streamReady(Metadata metadata, String path) { transport.isUsingPlaintext()); // TODO(b/145386688): This access should be guarded by 'this.transport.lock'; instead found: // 'this.lock' - 
transport.streamReadyToStart(OkHttpClientStream.this); + transport.streamReadyToStart(OkHttpClientStream.this, authority); } Tag tag() { diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index 29d3dbc1cdf..4764a6a1387 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -27,8 +27,10 @@ import com.google.common.base.Supplier; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ChannelCredentials; import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.HttpConnectProxiedSocketAddress; @@ -42,20 +44,28 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StatusException; +import io.grpc.TlsChannelCredentials; +import io.grpc.internal.CertificateUtils; import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.ConnectionClientTransport; +import io.grpc.internal.DisconnectError; +import io.grpc.internal.GoAwayDisconnectError; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; import io.grpc.internal.Http2Ping; import io.grpc.internal.InUseStateAggregator; import io.grpc.internal.KeepAliveManager; import io.grpc.internal.KeepAliveManager.ClientKeepAlivePinger; +import io.grpc.internal.NoopSslSession; import io.grpc.internal.SerializingExecutor; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; +import io.grpc.okhttp.OkHttpChannelBuilder.OkHttpTransportFactory; import io.grpc.okhttp.internal.ConnectionSpec; import io.grpc.okhttp.internal.Credentials; +import 
io.grpc.okhttp.internal.OkHostnameVerifier; import io.grpc.okhttp.internal.StatusLine; import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.FrameReader; @@ -70,31 +80,46 @@ import io.perfmark.PerfMark; import java.io.EOFException; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.Socket; import java.net.URI; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; import java.util.Collections; import java.util.Deque; import java.util.EnumMap; import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Random; +import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; import okio.Buffer; import okio.BufferedSink; import okio.BufferedSource; @@ -107,9 +132,15 @@ * A okhttp-based {@link ConnectionClientTransport} implementation. 
*/ class OkHttpClientTransport implements ConnectionClientTransport, TransportExceptionHandler, - OutboundFlowController.Transport { + OutboundFlowController.Transport, ClientKeepAlivePinger.TransportWithDisconnectReason { private static final Map ERROR_CODE_TO_STATUS = buildErrorCodeToStatusMap(); private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); + private static final String GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK = + "GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK"; + static boolean enablePerRpcAuthorityCheck = + GrpcUtil.getFlag(GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK, false); + private Socket sock; + private SSLSession sslSession; private static Map buildErrorCodeToStatusMap() { Map errorToStatus = new EnumMap<>(ErrorCode.class); @@ -140,6 +171,26 @@ private static Map buildErrorCodeToStatusMap() { return Collections.unmodifiableMap(errorToStatus); } + private static final Class x509ExtendedTrustManagerClass; + private static final Method checkServerTrustedMethod; + + static { + Class x509ExtendedTrustManagerClass1 = null; + Method checkServerTrustedMethod1 = null; + try { + x509ExtendedTrustManagerClass1 = Class.forName("javax.net.ssl.X509ExtendedTrustManager"); + checkServerTrustedMethod1 = x509ExtendedTrustManagerClass1.getMethod("checkServerTrusted", + X509Certificate[].class, String.class, Socket.class); + } catch (ClassNotFoundException e) { + // Per-rpc authority override via call options will be disallowed. + } catch (NoSuchMethodException e) { + // Should never happen since X509ExtendedTrustManager was introduced in Android API level 24 + // along with checkServerTrusted. 
+ } + x509ExtendedTrustManagerClass = x509ExtendedTrustManagerClass1; + checkServerTrustedMethod = checkServerTrustedMethod1; + } + private final InetSocketAddress address; private final String defaultAuthority; private final String userAgent; @@ -201,6 +252,19 @@ private static Map buildErrorCodeToStatusMap() { private final boolean useGetForSafeMethods; @GuardedBy("lock") private final TransportTracer transportTracer; + private final TrustManager x509TrustManager; + + @SuppressWarnings("serial") + private static class LruCache extends LinkedHashMap { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > 100; + } + } + + @GuardedBy("lock") + private final Map authorityVerificationResults = new LruCache<>(); + @GuardedBy("lock") private final InUseStateAggregator inUseState = new InUseStateAggregator() { @@ -229,13 +293,14 @@ protected void handleNotInUse() { SettableFuture connectedFuture; public OkHttpClientTransport( - OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, - InetSocketAddress address, - String authority, - @Nullable String userAgent, - Attributes eagAttrs, - @Nullable HttpConnectProxiedSocketAddress proxiedAddr, - Runnable tooManyPingsRunnable) { + OkHttpTransportFactory transportFactory, + InetSocketAddress address, + String authority, + @Nullable String userAgent, + Attributes eagAttrs, + @Nullable HttpConnectProxiedSocketAddress proxiedAddr, + Runnable tooManyPingsRunnable, + ChannelCredentials channelCredentials) { this( transportFactory, address, @@ -245,19 +310,21 @@ public OkHttpClientTransport( GrpcUtil.STOPWATCH_SUPPLIER, new Http2(), proxiedAddr, - tooManyPingsRunnable); + tooManyPingsRunnable, + channelCredentials); } private OkHttpClientTransport( - OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, - InetSocketAddress address, - String authority, - @Nullable String userAgent, - Attributes eagAttrs, - Supplier stopwatchFactory, - Variant variant, - @Nullable 
HttpConnectProxiedSocketAddress proxiedAddr, - Runnable tooManyPingsRunnable) { + OkHttpTransportFactory transportFactory, + InetSocketAddress address, + String authority, + @Nullable String userAgent, + Attributes eagAttrs, + Supplier stopwatchFactory, + Variant variant, + @Nullable HttpConnectProxiedSocketAddress proxiedAddr, + Runnable tooManyPingsRunnable, + ChannelCredentials channelCredentials) { this.address = Preconditions.checkNotNull(address, "address"); this.defaultAuthority = authority; this.maxMessageSize = transportFactory.maxMessageSize; @@ -272,7 +339,8 @@ private OkHttpClientTransport( this.socketFactory = transportFactory.socketFactory == null ? SocketFactory.getDefault() : transportFactory.socketFactory; this.sslSocketFactory = transportFactory.sslSocketFactory; - this.hostnameVerifier = transportFactory.hostnameVerifier; + this.hostnameVerifier = transportFactory.hostnameVerifier != null + ? transportFactory.hostnameVerifier : OkHostnameVerifier.INSTANCE; this.connectionSpec = Preconditions.checkNotNull( transportFactory.connectionSpec, "connectionSpec"); this.stopwatchFactory = Preconditions.checkNotNull(stopwatchFactory, "stopwatchFactory"); @@ -288,6 +356,21 @@ private OkHttpClientTransport( .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build(); this.useGetForSafeMethods = transportFactory.useGetForSafeMethods; initTransportTracer(); + TrustManager tempX509TrustManager; + if (channelCredentials instanceof TlsChannelCredentials + && x509ExtendedTrustManagerClass != null) { + try { + tempX509TrustManager = getTrustManager( + (TlsChannelCredentials) channelCredentials); + } catch (GeneralSecurityException e) { + tempX509TrustManager = null; + log.log(Level.WARNING, "Obtaining X509ExtendedTrustManager for the transport failed." 
+ + "Per-rpc authority overrides will be disallowed.", e); + } + } else { + tempX509TrustManager = null; + } + x509TrustManager = tempX509TrustManager; } /** @@ -296,7 +379,7 @@ private OkHttpClientTransport( @SuppressWarnings("AddressSelection") // An IP address always returns one address @VisibleForTesting OkHttpClientTransport( - OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, + OkHttpTransportFactory transportFactory, String userAgent, Supplier stopwatchFactory, Variant variant, @@ -312,7 +395,8 @@ private OkHttpClientTransport( stopwatchFactory, variant, null, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); this.connectingCallback = connectingCallback; this.connectedFuture = Preconditions.checkNotNull(connectedFuture, "connectedFuture"); } @@ -392,6 +476,7 @@ public OkHttpClientStream newStream( Preconditions.checkNotNull(headers, "headers"); StatsTraceContext statsTraceContext = StatsTraceContext.newClientContext(tracers, getAttributes(), headers); + // FIXME: it is likely wrong to pass the transportTracer here as it'll exit the lock's scope synchronized (lock) { // to make @GuardedBy linter happy return new OkHttpClientStream( @@ -412,23 +497,116 @@ public OkHttpClientStream newStream( } } + private TrustManager getTrustManager(TlsChannelCredentials tlsCreds) + throws GeneralSecurityException { + TrustManager[] tm; + // Using the same way of creating TrustManager from OkHttpChannelBuilder.sslSocketFactoryFrom() + if (tlsCreds.getTrustManagers() != null) { + tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]); + } else if (tlsCreds.getRootCertificates() != null) { + tm = CertificateUtils.createTrustManager(tlsCreds.getRootCertificates()); + } else { // else use system default + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init((KeyStore) null); + tm = tmf.getTrustManagers(); + } + for (TrustManager trustManager: tm) { + if (trustManager instanceof 
X509TrustManager) { + return trustManager; + } + } + return null; + } + @GuardedBy("lock") - void streamReadyToStart(OkHttpClientStream clientStream) { + void streamReadyToStart(OkHttpClientStream clientStream, String authority) { if (goAwayStatus != null) { clientStream.transportState().transportReportStatus( goAwayStatus, RpcProgress.MISCARRIED, true, new Metadata()); - } else if (streams.size() >= maxConcurrentStreams) { - pendingStreams.add(clientStream); - setInUse(clientStream); } else { - startStream(clientStream); + if (socket instanceof SSLSocket && !authority.equals(defaultAuthority)) { + Status authorityVerificationResult; + if (authorityVerificationResults.containsKey(authority)) { + authorityVerificationResult = authorityVerificationResults.get(authority); + } else { + authorityVerificationResult = verifyAuthority(authority); + authorityVerificationResults.put(authority, authorityVerificationResult); + } + if (!authorityVerificationResult.isOk()) { + if (enablePerRpcAuthorityCheck) { + clientStream.transportState().transportReportStatus( + authorityVerificationResult, RpcProgress.PROCESSED, true, new Metadata()); + return; + } + } + } + if (streams.size() >= maxConcurrentStreams) { + pendingStreams.add(clientStream); + setInUse(clientStream); + } else { + startStream(clientStream); + } } } + private Status verifyAuthority(String authority) { + Status authorityVerificationResult; + if (hostnameVerifier.verify(authority, ((SSLSocket) socket).getSession())) { + authorityVerificationResult = Status.OK; + } else { + authorityVerificationResult = Status.UNAVAILABLE.withDescription(String.format( + "HostNameVerifier verification failed for authority '%s'", + authority)); + } + if (!authorityVerificationResult.isOk() && !enablePerRpcAuthorityCheck) { + log.log(Level.WARNING, String.format("HostNameVerifier verification failed for " + + "authority '%s'. 
This will be an error in the future.", + authority)); + } + if (authorityVerificationResult.isOk()) { + // The status is trivially assigned in this case, but we are still making use of the + // cache to keep track that a warning log had been logged for the authority when + // enablePerRpcAuthorityCheck is false. When we permanently enable the feature, the + // status won't need to be cached for case when x509TrustManager is null. + if (x509TrustManager == null) { + authorityVerificationResult = Status.UNAVAILABLE.withDescription( + String.format("Could not verify authority '%s' for the rpc with no " + + "X509TrustManager available", + authority)); + } else if (x509ExtendedTrustManagerClass.isInstance(x509TrustManager)) { + try { + Certificate[] peerCertificates = sslSession.getPeerCertificates(); + X509Certificate[] x509PeerCertificates = + new X509Certificate[peerCertificates.length]; + for (int i = 0; i < peerCertificates.length; i++) { + x509PeerCertificates[i] = (X509Certificate) peerCertificates[i]; + } + checkServerTrustedMethod.invoke(x509TrustManager, x509PeerCertificates, + "RSA", new SslSocketWrapper((SSLSocket) socket, authority)); + authorityVerificationResult = Status.OK; + } catch (SSLPeerUnverifiedException | InvocationTargetException + | IllegalAccessException e) { + authorityVerificationResult = Status.UNAVAILABLE.withCause(e).withDescription( + "Peer verification failed"); + } + if (authorityVerificationResult.getCause() != null) { + log.log(Level.WARNING, authorityVerificationResult.getDescription() + + ". This will be an error in the future.", + authorityVerificationResult.getCause()); + } else { + log.log(Level.WARNING, authorityVerificationResult.getDescription() + + ". 
This will be an error in the future."); + } + } + } + return authorityVerificationResult; + } + @SuppressWarnings("GuardedBy") @GuardedBy("lock") private void startStream(OkHttpClientStream stream) { - Preconditions.checkState( + checkState( stream.transportState().id() == OkHttpClientStream.ABSENT_ID, "StreamId already assigned"); streams.put(nextStreamId, stream); setInUse(stream); @@ -499,20 +677,18 @@ public Runnable start(Listener listener) { outboundFlow = new OutboundFlowController(this, frameWriter); } final CountDownLatch latch = new CountDownLatch(1); + final CountDownLatch latchForExtraThread = new CountDownLatch(1); + // The transport needs up to two threads to function once started, + // but only needs one during handshaking. Start another thread during handshaking + // to make sure there's still a free thread available. If the number of threads is exhausted, + // it is better to kill the transport than for all the transports to hang unable to send. + CyclicBarrier barrier = new CyclicBarrier(2); // Connecting in the serializingExecutor, so that some stream operations like synStream // will be executed after connected. + serializingExecutor.execute(new Runnable() { @Override public void run() { - // This is a hack to make sure the connection preface and initial settings to be sent out - // without blocking the start. By doing this essentially prevents potential deadlock when - // network is not available during startup while another thread holding lock to send the - // initial preface. - try { - latch.await(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } // Use closed source on failure so that the reader immediately shuts down. 
BufferedSource source = Okio.buffer(new Source() { @Override @@ -529,9 +705,23 @@ public Timeout timeout() { public void close() { } }); - Socket sock; - SSLSession sslSession = null; try { + // This is a hack to make sure the connection preface and initial settings to be sent out + // without blocking the start. By doing this essentially prevents potential deadlock when + // network is not available during startup while another thread holding lock to send the + // initial preface. + try { + latch.await(); + barrier.await(1000, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (TimeoutException | BrokenBarrierException e) { + startGoAway(0, ErrorCode.INTERNAL_ERROR, Status.UNAVAILABLE + .withDescription("Timed out waiting for second handshake thread. " + + "The transport executor pool may have run out of threads")); + return; + } + if (proxiedAddr == null) { sock = socketFactory.createSocket(address.getAddress(), address.getPort()); } else { @@ -575,6 +765,7 @@ sslSocketFactory, hostnameVerifier, sock, getOverridenHost(), getOverridenPort() return; } finally { clientFrameHandler = new ClientFrameHandler(variant.newReader(source, true)); + latchForExtraThread.countDown(); } synchronized (lock) { socket = Preconditions.checkNotNull(sock, "socket"); @@ -584,6 +775,21 @@ sslSocketFactory, hostnameVerifier, sock, getOverridenHost(), getOverridenPort() } } }); + + executor.execute(new Runnable() { + @Override + public void run() { + try { + barrier.await(1000, TimeUnit.MILLISECONDS); + latchForExtraThread.await(); + } catch (BrokenBarrierException | TimeoutException e) { + // Something bad happened, maybe too few threads available! + // This will be handled in the handshake thread. + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }); // Schedule to send connection preface & settings before any other write. 
try { sendConnectionPrefaceAndSettings(); @@ -597,13 +803,15 @@ public void run() { if (connectingCallback != null) { connectingCallback.run(); } - // ClientFrameHandler need to be started after connectionPreface / settings, otherwise it - // may send goAway immediately. - executor.execute(clientFrameHandler); synchronized (lock) { maxConcurrentStreams = Integer.MAX_VALUE; - startPendingStreams(); + checkState(pendingStreams.isEmpty(), + "Pending streams detected during transport start." + + " RPCs should not be started before transport is ready."); } + // ClientFrameHandler need to be started after connectionPreface / settings, otherwise it + // may send goAway immediately. + executor.execute(clientFrameHandler); if (connectedFuture != null) { connectedFuture.set(null); } @@ -787,13 +995,18 @@ public void shutdown(Status reason) { } goAwayStatus = reason; - listener.transportShutdown(goAwayStatus); + listener.transportShutdown(goAwayStatus, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); stopIfNecessary(); } } @Override public void shutdownNow(Status reason) { + shutdownNow(reason, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + } + + @Override + public void shutdownNow(Status reason, DisconnectError disconnectError) { shutdown(reason); synchronized (lock) { Iterator> it = streams.entrySet().iterator(); @@ -883,7 +1096,13 @@ private void startGoAway(int lastKnownStreamId, ErrorCode errorCode, Status stat synchronized (lock) { if (goAwayStatus == null) { goAwayStatus = status; - listener.transportShutdown(status); + GrpcUtil.Http2Error http2Error; + if (errorCode == null) { + http2Error = GrpcUtil.Http2Error.NO_ERROR; + } else { + http2Error = GrpcUtil.Http2Error.forCode(errorCode.httpCode); + } + listener.transportShutdown(status, new GoAwayDisconnectError(http2Error)); } if (errorCode != null && !goAwaySent) { // Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated @@ -1028,12 +1247,12 @@ private void setInUse(OkHttpClientStream stream) { 
} } - private Throwable getPingFailure() { + private Status getPingFailure() { synchronized (lock) { if (goAwayStatus != null) { - return goAwayStatus.asException(); + return goAwayStatus; } else { - return Status.UNAVAILABLE.withDescription("Connection closed").asException(); + return Status.UNAVAILABLE.withDescription("Connection closed"); } } } @@ -1426,4 +1645,50 @@ public void alternateService(int streamId, String origin, ByteString protocol, S // TODO(madongfly): Deal with alternateService propagation } } + + /** + * SSLSocket wrapper that provides a fake SSLSession for handshake session. + */ + static final class SslSocketWrapper extends NoopSslSocket { + + private final SSLSession sslSession; + private final SSLSocket sslSocket; + + SslSocketWrapper(SSLSocket sslSocket, String peerHost) { + this.sslSocket = sslSocket; + this.sslSession = new FakeSslSession(peerHost); + } + + @Override + public SSLSession getHandshakeSession() { + return this.sslSession; + } + + @Override + public boolean isConnected() { + return sslSocket.isConnected(); + } + + @Override + public SSLParameters getSSLParameters() { + return sslSocket.getSSLParameters(); + } + } + + /** + * Fake SSLSession instance that provides the peer host name to verify for per-rpc check. 
+ */ + static class FakeSslSession extends NoopSslSession { + + private final String peerHost; + + FakeSslSession(String peerHost) { + this.peerHost = peerHost; + } + + @Override + public String getPeerHost() { + return peerHost; + } + } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpProtocolNegotiator.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpProtocolNegotiator.java index d09d6cccedd..0706a39d028 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpProtocolNegotiator.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpProtocolNegotiator.java @@ -19,6 +19,8 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.net.HostAndPort; +import com.google.common.net.InetAddresses; import io.grpc.internal.GrpcUtil; import io.grpc.okhttp.internal.OptionalMethod; import io.grpc.okhttp.internal.Platform; @@ -247,7 +249,9 @@ protected void configureTlsExtensions( } else { SET_USE_SESSION_TICKETS.invokeOptionalWithoutCheckedException(sslSocket, true); } - if (SET_SERVER_NAMES != null && SNI_HOST_NAME != null) { + if (SET_SERVER_NAMES != null + && SNI_HOST_NAME != null + && !InetAddresses.isInetAddress(HostAndPort.fromString(hostname).getHost())) { SET_SERVER_NAMES .invoke(sslParams, Collections.singletonList(SNI_HOST_NAME.newInstance(hostname))); } else { diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpReadableBuffer.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpReadableBuffer.java index 136ee8954a2..d65453722f0 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpReadableBuffer.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpReadableBuffer.java @@ -21,7 +21,6 @@ import java.io.EOFException; import java.io.IOException; import java.io.OutputStream; -import java.nio.ByteBuffer; /** * A {@link ReadableBuffer} implementation that is backed by an {@link okio.Buffer}. 
@@ -71,12 +70,6 @@ public void readBytes(byte[] dest, int destOffset, int length) { } } - @Override - public void readBytes(ByteBuffer dest) { - // We are not using it. - throw new UnsupportedOperationException(); - } - @Override public void readBytes(OutputStream dest, int length) throws IOException { buffer.writeTo(dest, length); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java index 068474d70bc..163d2023b1c 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java @@ -17,6 +17,7 @@ package io.grpc.okhttp; import static com.google.common.base.Preconditions.checkArgument; +import static io.grpc.internal.CertificateUtils.createTrustManager; import com.google.common.base.Preconditions; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -26,6 +27,7 @@ import io.grpc.ForwardingServerBuilder; import io.grpc.InsecureServerCredentials; import io.grpc.Internal; +import io.grpc.MetricRecorder; import io.grpc.ServerBuilder; import io.grpc.ServerCredentials; import io.grpc.ServerStreamTracer; @@ -110,7 +112,15 @@ public static OkHttpServerBuilder forPort(SocketAddress address, ServerCredentia return new OkHttpServerBuilder(address, result.factory); } - final ServerImplBuilder serverImplBuilder = new ServerImplBuilder(this::buildTransportServers); + final ServerImplBuilder serverImplBuilder = new ServerImplBuilder( + new ServerImplBuilder.ClientTransportServersBuilder() { + @Override + public InternalServer buildClientTransportServers( + List streamTracerFactories, + MetricRecorder metricRecorder) { + return buildTransportServers(streamTracerFactories); + } + }); final SocketAddress listenAddress; final HandshakerSocketFactory handshakerSocketFactory; TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory(); @@ -425,7 +435,7 @@ static HandshakerSocketFactoryResult 
handshakerSocketFactoryFrom(ServerCredentia tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]); } else if (tlsCreds.getRootCertificates() != null) { try { - tm = OkHttpChannelBuilder.createTrustManager(tlsCreds.getRootCertificates()); + tm = createTrustManager(tlsCreds.getRootCertificates()); } catch (GeneralSecurityException gse) { log.log(Level.FINE, "Exception loading root certificates from credential", gse); return HandshakerSocketFactoryResult.error( diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java index bcf8837b7eb..d1f1a3f4fe0 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java @@ -17,6 +17,7 @@ package io.grpc.okhttp; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.Metadata; import io.grpc.Status; @@ -30,7 +31,6 @@ import io.perfmark.Tag; import io.perfmark.TaskCloseable; import java.util.List; -import javax.annotation.concurrent.GuardedBy; import okio.Buffer; /** diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java index 2da041f571e..7d192b16943 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java @@ -20,8 +20,10 @@ import static io.grpc.okhttp.OkHttpServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED; import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.InternalChannelz; import io.grpc.InternalLogId; @@ -51,6 +53,7 @@ import java.io.IOException; import 
java.net.Socket; import java.net.SocketException; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; @@ -62,7 +65,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import okio.Buffer; import okio.BufferedSource; import okio.ByteString; @@ -91,6 +93,7 @@ final class OkHttpServerTransport implements ServerTransport, private static final ByteString TE_TRAILERS = ByteString.encodeUtf8("trailers"); private static final ByteString CONTENT_TYPE = ByteString.encodeUtf8("content-type"); private static final ByteString CONTENT_LENGTH = ByteString.encodeUtf8("content-length"); + private static final ByteString ALLOW = ByteString.encodeUtf8("allow"); private final Config config; private final Variant variant = new Http2(); @@ -772,8 +775,9 @@ public void headers(boolean outFinished, } if (!POST_METHOD.equals(httpMethod)) { + List
extraHeaders = Lists.newArrayList(new Header(ALLOW, POST_METHOD)); respondWithHttpError(streamId, inFinished, 405, Status.Code.INTERNAL, - "HTTP Method is not supported: " + asciiString(httpMethod)); + "HTTP Method is not supported: " + asciiString(httpMethod), extraHeaders); return; } @@ -947,13 +951,13 @@ public void settings(boolean clearPrevious, Settings settings) { @Override public void ping(boolean ack, int payload1, int payload2) { - if (!keepAliveEnforcer.pingAcceptable()) { - abruptShutdown(ErrorCode.ENHANCE_YOUR_CALM, "too_many_pings", - Status.RESOURCE_EXHAUSTED.withDescription("Too many pings from client"), false); - return; - } long payload = (((long) payload1) << 32) | (payload2 & 0xffffffffL); if (!ack) { + if (!keepAliveEnforcer.pingAcceptable()) { + abruptShutdown(ErrorCode.ENHANCE_YOUR_CALM, "too_many_pings", + Status.RESOURCE_EXHAUSTED.withDescription("Too many pings from client"), false); + return; + } frameLogger.logPing(OkHttpFrameLogger.Direction.INBOUND, payload); synchronized (lock) { frameWriter.ping(true, payload1, payload2); @@ -1066,11 +1070,19 @@ private void streamError(int streamId, ErrorCode errorCode, String reason) { private void respondWithHttpError( int streamId, boolean inFinished, int httpCode, Status.Code statusCode, String msg) { + respondWithHttpError(streamId, inFinished, httpCode, statusCode, msg, + Collections.emptyList()); + } + + private void respondWithHttpError( + int streamId, boolean inFinished, int httpCode, Status.Code statusCode, String msg, + List
extraHeaders) { Metadata metadata = new Metadata(); metadata.put(InternalStatus.CODE_KEY, statusCode.toStatus()); metadata.put(InternalStatus.MESSAGE_KEY, msg); List
headers = Headers.createHttpResponseHeaders(httpCode, "text/plain; charset=utf-8", metadata); + headers.addAll(extraHeaders); Buffer data = new Buffer().writeUtf8(msg); synchronized (lock) { diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpTlsUpgrader.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpTlsUpgrader.java index 1004dcd93f9..a8b038c91f4 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpTlsUpgrader.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpTlsUpgrader.java @@ -19,13 +19,13 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import io.grpc.okhttp.internal.ConnectionSpec; -import io.grpc.okhttp.internal.OkHostnameVerifier; import io.grpc.okhttp.internal.Protocol; import java.io.IOException; import java.net.Socket; import java.util.Arrays; import java.util.Collections; import java.util.List; +import javax.annotation.Nonnull; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSocket; @@ -52,7 +52,7 @@ final class OkHttpTlsUpgrader { * @throws RuntimeException if the upgrade negotiation failed. 
*/ public static SSLSocket upgrade(SSLSocketFactory sslSocketFactory, - HostnameVerifier hostnameVerifier, Socket socket, String host, int port, + @Nonnull HostnameVerifier hostnameVerifier, Socket socket, String host, int port, ConnectionSpec spec) throws IOException { Preconditions.checkNotNull(sslSocketFactory, "sslSocketFactory"); Preconditions.checkNotNull(socket, "socket"); @@ -67,9 +67,6 @@ public static SSLSocket upgrade(SSLSocketFactory sslSocketFactory, "Only " + TLS_PROTOCOLS + " are supported, but negotiated protocol is %s", negotiatedProtocol); - if (hostnameVerifier == null) { - hostnameVerifier = OkHostnameVerifier.INSTANCE; - } if (!hostnameVerifier.verify(canonicalizeHost(host), sslSocket.getSession())) { throw new SSLPeerUnverifiedException("Cannot verify hostname: " + host); } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpWritableBufferAllocator.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpWritableBufferAllocator.java index 481ada61c96..58896a5dbb0 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpWritableBufferAllocator.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpWritableBufferAllocator.java @@ -27,11 +27,9 @@ */ class OkHttpWritableBufferAllocator implements WritableBufferAllocator { - // Use 4k as our minimum buffer size. - private static final int MIN_BUFFER = 4096; - // Set the maximum buffer size to 1MB private static final int MAX_BUFFER = 1024 * 1024; + public static final int SEGMENT_SIZE_COPY = 8192; // Should equal Segment.SIZE /** * Construct a new instance. 
@@ -45,7 +43,9 @@ class OkHttpWritableBufferAllocator implements WritableBufferAllocator { */ @Override public WritableBuffer allocate(int capacityHint) { - capacityHint = Math.min(MAX_BUFFER, Math.max(MIN_BUFFER, capacityHint)); + // okio buffer uses fixed size Segments, round capacityHint up + capacityHint = Math.min(MAX_BUFFER, + (capacityHint + SEGMENT_SIZE_COPY - 1) / SEGMENT_SIZE_COPY * SEGMENT_SIZE_COPY); return new OkHttpWritableBuffer(new Buffer(), capacityHint); } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/Utils.java b/okhttp/src/main/java/io/grpc/okhttp/Utils.java index 2dc5f1e1ec9..4546143cf3b 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/Utils.java +++ b/okhttp/src/main/java/io/grpc/okhttp/Utils.java @@ -17,6 +17,7 @@ package io.grpc.okhttp; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.InternalChannelz; import io.grpc.InternalMetadata; import io.grpc.Metadata; @@ -29,7 +30,6 @@ import java.util.List; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; /** * Common utility methods for OkHttp transport. 
diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java index 3670cd057c1..89d37536b70 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; import com.google.common.util.concurrent.SettableFuture; @@ -34,6 +35,7 @@ import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; import io.grpc.TlsChannelCredentials; +import io.grpc.internal.CertificateUtils; import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; import io.grpc.internal.FakeClock; @@ -56,7 +58,6 @@ import javax.security.auth.x500.X500Principal; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -66,8 +67,6 @@ @RunWith(JUnit4.class) public class OkHttpChannelBuilderTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); @Test @@ -99,10 +98,9 @@ private void overrideAuthorityIsReadableHelper(OkHttpChannelBuilder builder, @Test public void failOverrideInvalidAuthority() { OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("good", 1234); - - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority:"); - builder.overrideAuthority("[invalidauthority"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> 
builder.overrideAuthority("[invalidauthority")); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [invalidauthority"); } @Test @@ -118,17 +116,16 @@ public void enableCheckAuthorityFailOverrideInvalidAuthority() { .disableCheckAuthority() .enableCheckAuthority(); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority:"); - builder.overrideAuthority("[invalidauthority"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.overrideAuthority("[invalidauthority")); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [invalidauthority"); } @Test public void failInvalidAuthority() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid host or port"); - - OkHttpChannelBuilder.forAddress("invalid_authority", 1234); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> OkHttpChannelBuilder.forAddress("invalid_authority", 1234)); + assertThat(e.getMessage()).isEqualTo("Invalid host or port: invalid_authority 1234"); } @Test @@ -212,7 +209,7 @@ public void sslSocketFactoryFrom_tls_mtls() throws Exception { TrustManager[] trustManagers; try (InputStream ca = TlsTesting.loadCert("ca.pem")) { - trustManagers = OkHttpChannelBuilder.createTrustManager(ca); + trustManagers = CertificateUtils.createTrustManager(ca); } SSLContext serverContext = SSLContext.getInstance("TLS"); @@ -257,7 +254,7 @@ public void sslSocketFactoryFrom_tls_mtls_keyFile() throws Exception { InputStream ca = TlsTesting.loadCert("ca.pem")) { serverContext.init( OkHttpChannelBuilder.createKeyManager(server1Chain, server1Key), - OkHttpChannelBuilder.createTrustManager(ca), + CertificateUtils.createTrustManager(ca), null); } final SSLServerSocket serverListenSocket = @@ -395,10 +392,10 @@ public ChannelCredentials withoutBearerTokens() { @Test public void failForUsingClearTextSpecDirectly() { - thrown.expect(IllegalArgumentException.class); - 
thrown.expectMessage("plaintext ConnectionSpec is not accepted"); - - OkHttpChannelBuilder.forAddress("host", 1234).connectionSpec(ConnectionSpec.CLEARTEXT); + OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("host", 1234); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.connectionSpec(ConnectionSpec.CLEARTEXT)); + assertThat(e).hasMessageThat().isEqualTo("plaintext ConnectionSpec is not accepted"); } @Test diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java index 1f716705968..1c98d6ee30d 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java @@ -20,6 +20,7 @@ import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.times; @@ -244,12 +245,13 @@ public void getUnaryRequest() throws IOException { // GET streams send headers after halfClose is called. 
verify(mockedFrameWriter, times(0)).synStream( eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); - verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class)); + verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class), + isA(String.class)); byte[] msg = "request".getBytes(Charset.forName("UTF-8")); stream.writeMessage(new ByteArrayInputStream(msg)); stream.halfClose(); - verify(transport).streamReadyToStart(eq(stream)); + verify(transport).streamReadyToStart(eq(stream), any(String.class)); stream.transportState().start(3); verify(mockedFrameWriter) diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index 987cc09203e..f87912c44ea 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -67,13 +67,16 @@ import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; import io.grpc.Status.Code; -import io.grpc.StatusException; import io.grpc.internal.AbstractStream; +import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientTransport; +import io.grpc.internal.DisconnectError; import io.grpc.internal.FakeClock; +import io.grpc.internal.GoAwayDisconnectError; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.okhttp.OkHttpClientTransport.ClientFrameHandler; import io.grpc.okhttp.OkHttpFrameLogger.Direction; import io.grpc.okhttp.internal.Protocol; @@ -116,6 +119,10 @@ import java.util.logging.Logger; import javax.annotation.Nullable; import javax.net.SocketFactory; +import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; import okio.Buffer; import okio.BufferedSink; import 
okio.BufferedSource; @@ -190,16 +197,24 @@ public void tearDown() { private void initTransport() throws Exception { startTransport( - DEFAULT_START_STREAM_ID, null, true, null); + DEFAULT_START_STREAM_ID, null, true, null, null); } private void initTransport(int startId) throws Exception { - startTransport(startId, null, true, null); + startTransport(startId, null, true, null, null); } private void startTransport(int startId, @Nullable Runnable connectingCallback, - boolean waitingForConnected, String userAgent) - throws Exception { + boolean waitingForConnected, String userAgent, + HostnameVerifier hostnameVerifier) throws Exception { + startTransport(startId, connectingCallback, waitingForConnected, userAgent, hostnameVerifier, + false); + } + + private void startTransport(int startId, @Nullable Runnable connectingCallback, + boolean waitingForConnected, String userAgent, + HostnameVerifier hostnameVerifier, boolean useSslSocket) + throws Exception { connectedFuture = SettableFuture.create(); final Ticker ticker = new Ticker() { @Override @@ -213,7 +228,11 @@ public Stopwatch get() { return Stopwatch.createUnstarted(ticker); } }; - channelBuilder.socketFactory(new FakeSocketFactory(socket)); + channelBuilder.socketFactory( + new FakeSocketFactory(useSslSocket ? 
new MockSslSocket(socket) : socket)); + if (hostnameVerifier != null) { + channelBuilder = channelBuilder.hostnameVerifier(hostnameVerifier); + } clientTransport = new OkHttpClientTransport( channelBuilder.buildTransportFactory(), userAgent, @@ -241,12 +260,37 @@ public void testToString() throws Exception { /*userAgent=*/ null, EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); String s = clientTransport.toString(); assertTrue("Unexpected: " + s, s.contains("OkHttpClientTransport")); assertTrue("Unexpected: " + s, s.contains(address.toString())); } + @Test + public void testTransportExecutorWithTooFewThreads() throws Exception { + ExecutorService fixedPoolExecutor = Executors.newFixedThreadPool(1); + channelBuilder.transportExecutor(fixedPoolExecutor); + InetSocketAddress address = InetSocketAddress.createUnresolved("hostname", 31415); + clientTransport = new OkHttpClientTransport( + channelBuilder.buildTransportFactory(), + address, + "hostname", + null, + EAG_ATTRS, + NO_PROXY, + tooManyPingsRunnable, + null); + clientTransport.start(transportListener); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture(), + eq(new GoAwayDisconnectError(GrpcUtil.Http2Error.INTERNAL_ERROR))); + Status capturedStatus = statusCaptor.getValue(); + assertEquals("Timed out waiting for second handshake thread. " + + "The transport executor pool may have run out of threads", + capturedStatus.getDescription()); + } + /** * Test logging is functioning correctly for client received Http/2 frames. Not intended to test * actual frame content being logged. 
@@ -278,7 +322,7 @@ public void close() throws SecurityException { assertThat(log.getLevel()).isEqualTo(Level.FINE); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -368,7 +412,7 @@ public void maxMessageSizeShouldBeEnforced() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -421,11 +465,11 @@ public void nextFrameThrowIoException() throws Exception { initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); @@ -441,7 +485,8 @@ public void nextFrameThrowIoException() throws Exception { assertEquals(NETWORK_ISSUE_MESSAGE, listener1.status.getCause().getMessage()); assertEquals(Status.INTERNAL.getCode(), listener2.status.getCode()); assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage()); - verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class)); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -455,7 +500,7 @@ public void nextFrameThrowIoException() throws Exception { public void nextFrameThrowsError() throws 
Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -467,7 +512,8 @@ public void nextFrameThrowsError() throws Exception { assertEquals(0, activeStreamCount()); assertEquals(Status.INTERNAL.getCode(), listener.status.getCode()); assertEquals(ERROR_MESSAGE, listener.status.getCause().getMessage()); - verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class)); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -476,14 +522,15 @@ public void nextFrameThrowsError() throws Exception { public void nextFrameReturnFalse() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); frameReader.nextFrameAtEndOfStream(); listener.waitUntilStreamClosed(); assertEquals(Status.UNAVAILABLE.getCode(), listener.status.getCode()); - verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class)); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -494,7 +541,7 @@ public void readMessages() throws Exception { final int numMessages = 10; final String message = "Hello Client"; MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); 
stream.request(numMessages); @@ -525,7 +572,8 @@ public void receivedHeadersForInvalidStreamShouldKillConnection() throws Excepti HeadersMode.HTTP_20_HEADERS); verify(frameWriter, timeout(TIME_OUT_MS)) .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); - verify(transportListener).transportShutdown(isA(Status.class)); + verify(transportListener).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -537,7 +585,8 @@ public void receivedDataForInvalidStreamShouldKillConnection() throws Exception 1000, 1000); verify(frameWriter, timeout(TIME_OUT_MS)) .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); - verify(transportListener).transportShutdown(isA(Status.class)); + verify(transportListener).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -546,7 +595,7 @@ public void receivedDataForInvalidStreamShouldKillConnection() throws Exception public void invalidInboundHeadersCancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -571,7 +620,7 @@ public void invalidInboundHeadersCancelStream() throws Exception { public void invalidInboundTrailersPropagateToMetadata() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -591,7 +640,7 @@ public void invalidInboundTrailersPropagateToMetadata() throws Exception { public void readStatus() throws Exception { initTransport(); 
MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); @@ -605,7 +654,7 @@ public void readStatus() throws Exception { public void receiveReset() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); @@ -622,7 +671,7 @@ public void receiveReset() throws Exception { public void receiveResetNoError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); @@ -643,7 +692,7 @@ public void receiveResetNoError() throws Exception { public void cancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.CANCELLED); @@ -658,7 +707,7 @@ public void cancelStream() throws Exception { public void addDefaultUserAgent() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); Header userAgentHeader = new Header(GrpcUtil.USER_AGENT_KEY.name(), @@ -675,9 +724,9 @@ public void addDefaultUserAgent() throws Exception { @Test public void overrideDefaultUserAgent() throws Exception { - startTransport(3, null, true, "fakeUserAgent"); + startTransport(3, 
null, true, "fakeUserAgent", null); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); List
expectedHeaders = Arrays.asList(HTTP_SCHEME_HEADER, METHOD_HEADER, @@ -696,7 +745,7 @@ public void overrideDefaultUserAgent() throws Exception { public void cancelStreamForDeadlineExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.DEADLINE_EXCEEDED); @@ -710,7 +759,7 @@ public void writeMessage() throws Exception { initTransport(); final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); @@ -725,6 +774,65 @@ public void writeMessage() throws Exception { shutdownAndVerify(); } + @Test + public void perRpcAuthoritySpecified_verificationSkippedInPlainTextConnection() + throws Exception { + initTransport(); + final String message = "Hello Server"; + MockStreamListener listener = new MockStreamListener(); + ClientStream stream = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + stream.start(listener); + InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); + assertEquals(12, input.available()); + stream.writeMessage(input); + stream.flush(); + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(3), any(Buffer.class), eq(12 + HEADER_LENGTH)); + Buffer sentFrame = capturedBuffer.poll(); + assertEquals(createMessageFrame(message), sentFrame); + stream.cancel(Status.CANCELLED); + shutdownAndVerify(); + } + + @Test + public void perRpcAuthoritySpecified_hostnameVerification_ignoredForNonSslSocket() + throws Exception { + startTransport( + DEFAULT_START_STREAM_ID, null, true, null, + 
(hostname, session) -> false, false); + ClientStream unused = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + shutdownAndVerify(); + } + + @Test + public void perRpcAuthoritySpecified_hostnameVerification_SslSocket_successCase() + throws Exception { + startTransport( + DEFAULT_START_STREAM_ID, null, true, null, + (hostname, session) -> true, true); + ClientStream unused = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + shutdownAndVerify(); + } + + @Test + public void perRpcAuthoritySpecified_hostnameVerification_SslSocket_flagDisabled() + throws Exception { + startTransport( + DEFAULT_START_STREAM_ID, null, true, null, + (hostname, session) -> false, true); + ClientStream clientStream = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + assertThat(clientStream).isInstanceOf(OkHttpClientStream.class); + shutdownAndVerify(); + } + @Test public void transportTracer_windowSizeDefault() throws Exception { initTransport(); @@ -751,12 +859,12 @@ public void windowUpdate() throws Exception { initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(2); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(2); @@ -821,7 +929,7 @@ public void windowUpdate() throws Exception { public void windowUpdateWithInboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = 
clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = INITIAL_WINDOW_SIZE / 2 + 1; @@ -858,7 +966,7 @@ public void windowUpdateWithInboundFlowControl() throws Exception { public void outboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); @@ -904,7 +1012,7 @@ public void outboundFlowControl_smallWindowSize() throws Exception { setInitialWindowSize(initialOutboundWindowSize); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); @@ -947,7 +1055,7 @@ public void outboundFlowControl_bigWindowSize() throws Exception { frameHandler().windowUpdate(0, 65535); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); @@ -983,7 +1091,7 @@ public void outboundFlowControl_bigWindowSize() throws Exception { public void outboundFlowControlWithInitialWindowSizeChange() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; @@ -1029,7 +1137,7 @@ public void outboundFlowControlWithInitialWindowSizeChange() throws Exception { public void outboundFlowControlWithInitialWindowSizeChangeInMiddleOfStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = 
clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; @@ -1064,17 +1172,18 @@ public void stopNormally() throws Exception { initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); assertEquals(2, activeStreamCount()); clientTransport.shutdown(SHUTDOWN_REASON); assertEquals(2, activeStreamCount()); - verify(transportListener).transportShutdown(same(SHUTDOWN_REASON)); + verify(transportListener).transportShutdown(same(SHUTDOWN_REASON), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); stream1.cancel(Status.CANCELLED); stream2.cancel(Status.CANCELLED); @@ -1094,11 +1203,11 @@ public void receiveGoAway() throws Exception { // start 2 streams. MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); @@ -1108,7 +1217,8 @@ public void receiveGoAway() throws Exception { frameHandler().goAway(3, ErrorCode.CANCEL, ByteString.EMPTY); // Transport should be in STOPPING state. 
- verify(transportListener).transportShutdown(isA(Status.class)); + verify(transportListener).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, never()).transportTerminated(); // Stream 2 should be closed. @@ -1121,7 +1231,7 @@ public void receiveGoAway() throws Exception { // But stream 1 should be able to send. final String sentMessage = "Should I also go away?"; - OkHttpClientStream stream = getStream(3); + ClientStream stream = getStream(3); InputStream input = new ByteArrayInputStream(sentMessage.getBytes(UTF_8)); assertEquals(22, input.available()); stream.writeMessage(input); @@ -1153,7 +1263,7 @@ public void streamIdExhausted() throws Exception { initTransport(startId); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1178,7 +1288,8 @@ public void streamIdExhausted() throws Exception { // Should only have the first message delivered. assertEquals(message, listener.messages.get(0)); verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(startId), eq(ErrorCode.CANCEL)); - verify(transportListener).transportShutdown(isA(Status.class)); + verify(transportListener).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -1189,11 +1300,11 @@ public void pendingStreamSucceed() throws Exception { setMaxConcurrentStreams(1); final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. 
- OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); String sentMessage = "hello"; @@ -1226,7 +1337,7 @@ public void pendingStreamCancelled() throws Exception { initTransport(); setMaxConcurrentStreams(0); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); @@ -1245,11 +1356,11 @@ public void pendingStreamFailedByGoAway() throws Exception { setMaxConcurrentStreams(1); final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); @@ -1275,7 +1386,7 @@ public void pendingStreamSucceedAfterShutdown() throws Exception { setMaxConcurrentStreams(0); final MockStreamListener listener = new MockStreamListener(); // The second stream should be pending. - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); @@ -1299,15 +1410,15 @@ public void pendingStreamFailedByIdExhausted() throws Exception { final MockStreamListener listener2 = new MockStreamListener(); final MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second and third stream should be pending. 
- OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); @@ -1331,7 +1442,7 @@ public void pendingStreamFailedByIdExhausted() throws Exception { public void receivingWindowExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1383,7 +1494,7 @@ public void duplexStreamingHeadersShouldNotBeFlushed() throws Exception { private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); verify(frameWriter, timeout(TIME_OUT_MS)).synStream( @@ -1400,7 +1511,7 @@ private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception { public void receiveDataWithoutHeader() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1423,7 +1534,7 @@ public void receiveDataWithoutHeader() throws Exception { public void receiveDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ 
-1447,7 +1558,7 @@ public void receiveDataWithoutHeaderAndTrailer() throws Exception { public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1469,7 +1580,7 @@ public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); @@ -1489,7 +1600,8 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception (int) buffer.size()); verify(frameWriter, timeout(TIME_OUT_MS)) .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); - verify(transportListener).transportShutdown(isA(Status.class)); + verify(transportListener).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -1498,7 +1610,7 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception public void receiveWindowUpdateForUnknownStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); @@ -1509,7 +1621,8 @@ public void receiveWindowUpdateForUnknownStream() throws Exception { frameHandler().windowUpdate(5, 73); verify(frameWriter, timeout(TIME_OUT_MS)) .goAway(eq(0), 
eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); - verify(transportListener).transportShutdown(isA(Status.class)); + verify(transportListener).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -1518,7 +1631,7 @@ public void receiveWindowUpdateForUnknownStream() throws Exception { public void shouldBeInitiallyReady() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); @@ -1536,7 +1649,7 @@ public void notifyOnReady() throws Exception { AbstractStream.TransportState.DEFAULT_ONREADY_THRESHOLD - HEADER_LENGTH - 1; setInitialWindowSize(0); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); @@ -1642,16 +1755,14 @@ public void ping_failsWhenTransportShutdown() throws Exception { clientTransport.shutdown(SHUTDOWN_REASON); // ping failed on channel shutdown assertEquals(1, callback.invocationCount); - assertTrue(callback.failureCause instanceof StatusException); - assertSame(SHUTDOWN_REASON, ((StatusException) callback.failureCause).getStatus()); + assertSame(SHUTDOWN_REASON, callback.failureCause); // now that handler is in terminal state, all future pings fail immediately callback = new PingCallbackImpl(); clientTransport.ping(callback, MoreExecutors.directExecutor()); assertEquals(1, getTransportStats(clientTransport).keepAlivesSent); assertEquals(1, callback.invocationCount); - assertTrue(callback.failureCause instanceof StatusException); - assertSame(SHUTDOWN_REASON, ((StatusException) callback.failureCause).getStatus()); + 
assertSame(SHUTDOWN_REASON, callback.failureCause); shutdownAndVerify(); } @@ -1666,18 +1777,14 @@ public void ping_failsIfTransportFails() throws Exception { clientTransport.onException(new IOException()); // ping failed on error assertEquals(1, callback.invocationCount); - assertTrue(callback.failureCause instanceof StatusException); - assertEquals(Status.Code.UNAVAILABLE, - ((StatusException) callback.failureCause).getStatus().getCode()); + assertEquals(Status.Code.UNAVAILABLE, callback.failureCause.getCode()); // now that handler is in terminal state, all future pings fail immediately callback = new PingCallbackImpl(); clientTransport.ping(callback, MoreExecutors.directExecutor()); assertEquals(1, getTransportStats(clientTransport).keepAlivesSent); assertEquals(1, callback.invocationCount); - assertTrue(callback.failureCause instanceof StatusException); - assertEquals(Status.Code.UNAVAILABLE, - ((StatusException) callback.failureCause).getStatus().getCode()); + assertEquals(Status.Code.UNAVAILABLE, callback.failureCause.getCode()); shutdownAndVerify(); } @@ -1689,7 +1796,7 @@ public void shutdownDuringConnecting() throws Exception { DEFAULT_START_STREAM_ID, connectingCallback, false, - null); + null, null); clientTransport.shutdown(SHUTDOWN_REASON); delayed.set(null); shutdownAndVerify(); @@ -1704,7 +1811,8 @@ public void invalidAuthorityPropagates() { "userAgent", EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); String host = clientTransport.getOverridenHost(); int port = clientTransport.getOverridenPort(); @@ -1722,13 +1830,15 @@ public void unreachableServer() throws Exception { "userAgent", EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); clientTransport.start(listener); - ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); - verify(listener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture()); 
- Status status = captor.getValue(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(listener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture(), + eq(new GoAwayDisconnectError(GrpcUtil.Http2Error.INTERNAL_ERROR))); + Status status = statusCaptor.getValue(); assertEquals(Status.UNAVAILABLE.getCode(), status.getCode()); assertTrue(status.getCause().toString(), status.getCause() instanceof IOException); @@ -1752,13 +1862,15 @@ public void customSocketFactory() throws Exception { "userAgent", EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); clientTransport.start(listener); - ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); - verify(listener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture()); - Status status = captor.getValue(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(listener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture(), + eq(new GoAwayDisconnectError(GrpcUtil.Http2Error.INTERNAL_ERROR))); + Status status = statusCaptor.getValue(); assertEquals(Status.UNAVAILABLE.getCode(), status.getCode()); assertSame(exception, status.getCause()); } @@ -1777,7 +1889,8 @@ public void proxy_200() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -1806,7 +1919,8 @@ public void proxy_200() throws Exception { }); sock.getOutputStream().flush(); - verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class)); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class), + any(DisconnectError.class)); while (sock.getInputStream().read() != -1) {} verify(transportListener, 
timeout(TIME_OUT_MS)).transportTerminated(); sock.close(); @@ -1826,7 +1940,8 @@ public void proxy_500() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -1845,17 +1960,18 @@ public void proxy_500() throws Exception { assertEquals(-1, sock.getInputStream().read()); - ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); - verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture()); - Status error = captor.getValue(); - assertTrue("Status didn't contain error code: " + captor.getValue(), - error.getDescription().contains("500")); - assertTrue("Status didn't contain error description: " + captor.getValue(), - error.getDescription().contains("OH NO")); - assertTrue("Status didn't contain error text: " + captor.getValue(), - error.getDescription().contains(errorText)); - assertEquals("Not UNAVAILABLE: " + captor.getValue(), - Status.UNAVAILABLE.getCode(), error.getCode()); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture(), + eq(new GoAwayDisconnectError(GrpcUtil.Http2Error.INTERNAL_ERROR))); + Status status = statusCaptor.getValue(); + assertTrue("Status didn't contain error code: " + statusCaptor.getValue(), + status.getDescription().contains("500")); + assertTrue("Status didn't contain error description: " + statusCaptor.getValue(), + status.getDescription().contains("OH NO")); + assertTrue("Status didn't contain error text: " + statusCaptor.getValue(), + status.getDescription().contains(errorText)); + assertEquals("Not UNAVAILABLE: " + statusCaptor.getValue(), + Status.UNAVAILABLE.getCode(), status.getCode()); sock.close(); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); 
} @@ -1874,20 +1990,22 @@ public void proxy_immediateServerClose() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); serverSocket.close(); sock.close(); - ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); - verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture()); - Status error = captor.getValue(); - assertTrue("Status didn't contain proxy: " + captor.getValue(), - error.getDescription().contains("proxy")); - assertEquals("Not UNAVAILABLE: " + captor.getValue(), - Status.UNAVAILABLE.getCode(), error.getCode()); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture(), + eq(new GoAwayDisconnectError(GrpcUtil.Http2Error.INTERNAL_ERROR))); + Status status = statusCaptor.getValue(); + assertTrue("Status didn't contain proxy: " + statusCaptor.getValue(), + status.getDescription().contains("proxy")); + assertEquals("Not UNAVAILABLE: " + statusCaptor.getValue(), + Status.UNAVAILABLE.getCode(), status.getCode()); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } @@ -1905,7 +2023,8 @@ public void proxy_serverHangs() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.proxySocketTimeout = 10; clientTransport.start(transportListener); @@ -1917,7 +2036,8 @@ public void proxy_serverHangs() throws Exception { assertEquals("Host: theservice:80", reader.readLine()); while (!"".equals(reader.readLine())) {} - verify(transportListener, timeout(200)).transportShutdown(any(Status.class)); + verify(transportListener, 
timeout(200)).transportShutdown(any(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); sock.close(); } @@ -1972,13 +2092,13 @@ public void goAway_streamListenerRpcProgress() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(1); @@ -2012,13 +2132,13 @@ public void reset_streamListenerRpcProgress() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); @@ -2054,13 +2174,13 @@ public void shutdownNow_streamListenerRpcProgress() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + 
ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(1); @@ -2085,11 +2205,11 @@ public void finishedStreamRemovedFromInUseState() throws Exception { initTransport(); setMaxConcurrentStreams(1); final MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); + OkHttpClientStream stream = clientTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); - OkHttpClientStream pendingStream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); + OkHttpClientStream pendingStream = clientTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); pendingStream.start(listener); waitForStreamPending(1); clientTransport.finishStream(stream.transportState().id(), Status.OK, PROCESSED, @@ -2129,7 +2249,7 @@ private void waitForStreamPending(int expected) throws Exception { private void assertNewStreamFail() throws Exception { MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); listener.waitUntilStreamClosed(); @@ -2360,10 +2480,128 @@ public InputStream getInputStream() { } } + private static class MockSslSocket extends SSLSocket { + private Socket delegate; + + MockSslSocket(Socket socket) { + delegate = socket; + } + + @Override + public String[] getSupportedCipherSuites() { + return new String[0]; + 
} + + @Override + public String[] getEnabledCipherSuites() { + return new String[0]; + } + + @Override + public void setEnabledCipherSuites(String[] suites) { + + } + + @Override + public String[] getSupportedProtocols() { + return new String[0]; + } + + @Override + public String[] getEnabledProtocols() { + return new String[0]; + } + + @Override + public void setEnabledProtocols(String[] protocols) { + + } + + @Override + public SSLSession getSession() { + return null; + } + + @Override + public void addHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void startHandshake() throws IOException { + + } + + @Override + public void setUseClientMode(boolean mode) { + + } + + @Override + public boolean getUseClientMode() { + return false; + } + + @Override + public void setNeedClientAuth(boolean need) { + + } + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean want) { + + } + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean flag) { + + } + + @Override + public boolean getEnableSessionCreation() { + return false; + } + + @Override + public synchronized void close() throws IOException { + delegate.close(); + } + + @Override + public SocketAddress getLocalSocketAddress() { + return delegate.getLocalSocketAddress(); + } + + @Override + public OutputStream getOutputStream() throws IOException { + return delegate.getOutputStream(); + } + + @Override + public InputStream getInputStream() throws IOException { + return delegate.getInputStream(); + } + } + static class PingCallbackImpl implements ClientTransport.PingCallback { int invocationCount; long roundTripTime; - Throwable failureCause; + Status failureCause; @Override public void onSuccess(long roundTripTimeNanos) { @@ 
-2372,7 +2610,7 @@ public void onSuccess(long roundTripTimeNanos) { } @Override - public void onFailure(Throwable cause) { + public void onFailure(Status cause) { invocationCount++; this.failureCause = cause; } diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpProtocolNegotiatorTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpProtocolNegotiatorTest.java index cc9f30862af..4353dc2597b 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpProtocolNegotiatorTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpProtocolNegotiatorTest.java @@ -16,10 +16,12 @@ package io.grpc.okhttp; +import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -37,9 +39,7 @@ import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentMatchers; @@ -49,9 +49,6 @@ */ @RunWith(JUnit4.class) public class OkHttpProtocolNegotiatorTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); - private final SSLSocket sock = mock(SSLSocket.class); private final Platform platform = mock(Platform.class); @@ -118,21 +115,19 @@ public void negotiate_handshakeFails() throws IOException { OkHttpProtocolNegotiator negotiator = OkHttpProtocolNegotiator.get(); doReturn(parameters).when(sock).getSSLParameters(); doThrow(new IOException()).when(sock).startHandshake(); - thrown.expect(IOException.class); - - 
negotiator.negotiate(sock, "hostname", ImmutableList.of(Protocol.HTTP_2)); + assertThrows(IOException.class, + () -> negotiator.negotiate(sock, "hostname", ImmutableList.of(Protocol.HTTP_2))); } @Test - public void negotiate_noSelectedProtocol() throws Exception { + public void negotiate_noSelectedProtocol() { Platform platform = mock(Platform.class); OkHttpProtocolNegotiator negotiator = new OkHttpProtocolNegotiator(platform); - thrown.expect(RuntimeException.class); - thrown.expectMessage("TLS ALPN negotiation failed"); - - negotiator.negotiate(sock, "hostname", ImmutableList.of(Protocol.HTTP_2)); + RuntimeException e = assertThrows(RuntimeException.class, + () -> negotiator.negotiate(sock, "hostname", ImmutableList.of(Protocol.HTTP_2))); + assertThat(e).hasMessageThat().isEqualTo("TLS ALPN negotiation failed with protocols: [h2]"); } @Test @@ -150,7 +145,7 @@ public void negotiate_success() throws Exception { // Checks that the super class is properly invoked. @Test - public void negotiate_android_handshakeFails() throws Exception { + public void negotiate_android_handshakeFails() { when(platform.getTlsExtensionType()).thenReturn(TlsExtensionType.ALPN_AND_NPN); AndroidNegotiator negotiator = new AndroidNegotiator(platform); @@ -161,10 +156,9 @@ public void startHandshake() throws IOException { } }; - thrown.expect(IOException.class); - thrown.expectMessage("expected"); - - negotiator.negotiate(androidSock, "hostname", ImmutableList.of(Protocol.HTTP_2)); + IOException e = assertThrows(IOException.class, + () -> negotiator.negotiate(androidSock, "hostname", ImmutableList.of(Protocol.HTTP_2))); + assertThat(e).hasMessageThat().isEqualTo("expected"); } @VisibleForTesting diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java index 4aeeae2fa8b..be8dbf0e62b 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java +++ 
b/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java @@ -44,18 +44,6 @@ public void setup() { } } - @Override - @Test - public void readToByteBufferShouldSucceed() { - // Not supported. - } - - @Override - @Test - public void partialReadToByteBufferShouldSucceed() { - // Not supported. - } - @Override @Test public void markAndResetWithReadShouldSucceed() { diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java index d64d314d7d8..00db6e1d339 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java @@ -34,6 +34,7 @@ import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; +import com.google.common.collect.Lists; import com.google.common.io.ByteStreams; import io.grpc.Attributes; import io.grpc.InternalChannelz.SocketStats; @@ -62,6 +63,7 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Deque; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -919,8 +921,9 @@ public void httpGet_failsWith405() throws Exception { CONTENT_TYPE_HEADER, TE_HEADER)); clientFrameWriter.flush(); - - verifyHttpError(1, 405, Status.Code.INTERNAL, "HTTP Method is not supported: GET"); + List
extraHeaders = Lists.newArrayList(new Header("allow", "POST")); + verifyHttpError(1, 405, Status.Code.INTERNAL, "HTTP Method is not supported: GET", + extraHeaders); shutdownAndTerminate(/*lastStreamId=*/ 1); } @@ -976,7 +979,8 @@ public void httpErrorsAdhereToFlowControl() throws Exception { new Header(":status", "405"), new Header("content-type", "text/plain; charset=utf-8"), new Header("grpc-status", "" + Status.Code.INTERNAL.value()), - new Header("grpc-message", errorDescription)); + new Header("grpc-message", errorDescription), + new Header("allow", "POST")); assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); verify(clientFramesRead) .headers(false, false, 1, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); @@ -1264,6 +1268,31 @@ public void keepAliveEnforcer_noticesActive() throws Exception { eq(ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII))); } + @Test + public void keepAliveEnforcer_doesNotEnforcePingAcks() throws Exception { + serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS) + .permitKeepAliveWithoutCalls(true); + initTransport(); + handshake(); + + for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES + 2; i++) { + int serverPingId = 0xDEAD + i; + clientFrameWriter.ping(true, serverPingId, 0); + clientFrameWriter.flush(); + } + + for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) { + pingPong(); + } + + pingPongId++; + clientFrameWriter.ping(false, pingPongId, 0); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM, + ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII)); + } + @Test public void maxConcurrentCallsPerConnection_failsWithRst() throws Exception { int maxConcurrentCallsPerConnection = 1; @@ -1398,11 +1427,18 @@ private void pingPong() throws IOException { private void verifyHttpError( int streamId, int httpCode, Status.Code grpcCode, String errorDescription) throws Exception 
{ - List
responseHeaders = Arrays.asList( + verifyHttpError(streamId, httpCode, grpcCode, errorDescription, Collections.emptyList()); + } + + private void verifyHttpError( + int streamId, int httpCode, Status.Code grpcCode, String errorDescription, + List
extraHeaders) throws Exception { + List
responseHeaders = Lists.newArrayList( new Header(":status", "" + httpCode), new Header("content-type", "text/plain; charset=utf-8"), new Header("grpc-status", "" + grpcCode.value()), new Header("grpc-message", errorDescription)); + responseHeaders.addAll(extraHeaders); assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); verify(clientFramesRead) .headers(false, false, streamId, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpWritableBufferAllocatorTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpWritableBufferAllocatorTest.java index e606b6b9a50..c19224822a8 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpWritableBufferAllocatorTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpWritableBufferAllocatorTest.java @@ -16,11 +16,13 @@ package io.grpc.okhttp; +import static io.grpc.okhttp.OkHttpWritableBufferAllocator.SEGMENT_SIZE_COPY; import static org.junit.Assert.assertEquals; import io.grpc.internal.WritableBuffer; import io.grpc.internal.WritableBufferAllocator; import io.grpc.internal.WritableBufferAllocatorTestBase; +import okio.Segment; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -38,11 +40,12 @@ protected WritableBufferAllocator allocator() { return allocator; } + @SuppressWarnings("KotlinInternal") @Test public void testCapacity() { WritableBuffer buffer = allocator().allocate(4096); assertEquals(0, buffer.readableBytes()); - assertEquals(4096, buffer.writableBytes()); + assertEquals(SEGMENT_SIZE_COPY, buffer.writableBytes()); } @Test @@ -54,8 +57,14 @@ public void testInitialCapacityHasMaximum() { @Test public void testIsExactBelowMaxCapacity() { - WritableBuffer buffer = allocator().allocate(4097); + WritableBuffer buffer = allocator().allocate(SEGMENT_SIZE_COPY + 1); assertEquals(0, buffer.readableBytes()); - assertEquals(4097, buffer.writableBytes()); + assertEquals(SEGMENT_SIZE_COPY * 2, buffer.writableBytes()); + } + 
+ @SuppressWarnings("KotlinInternal") + @Test + public void testSegmentSizeMatchesKotlin() { + assertEquals(Segment.SIZE, SEGMENT_SIZE_COPY); } } diff --git a/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java index a21360a89ba..20a2f1a5ca7 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java @@ -18,8 +18,10 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; +import static org.junit.Assert.fail; import com.google.common.base.Throwables; +import io.grpc.CallOptions; import io.grpc.ChannelCredentials; import io.grpc.ConnectivityState; import io.grpc.ManagedChannel; @@ -32,18 +34,34 @@ import io.grpc.TlsServerCredentials; import io.grpc.internal.testing.TestUtils; import io.grpc.okhttp.internal.Platform; +import io.grpc.stub.ClientCalls; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.TlsTesting; import io.grpc.testing.protobuf.SimpleRequest; import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.grpc.util.CertificateUtils; import java.io.IOException; import java.io.InputStream; +import java.net.Socket; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Optional; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import javax.security.auth.x500.X500Principal; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; 
import org.junit.Assume; import org.junit.Before; import org.junit.Rule; @@ -53,6 +71,7 @@ /** Verify OkHttp's TLS integration. */ @RunWith(JUnit4.class) +@IgnoreJRERequirement public class TlsTest { @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); @@ -92,6 +111,325 @@ public void basicTls_succeeds() throws Exception { SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()); } + @Test + public void perRpcAuthorityOverride_hostnameVerifier_goodAuthority_succeeds() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("good.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_hostnameVerifier_badAuthority_fails() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try 
(InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("disallowed.name.com"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for hostname verifier failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getStatus().getDescription()).isEqualTo( + "HostNameVerifier verification failed for authority 'disallowed.name.com'"); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_hostnameVerifier_badAuthority_flagDisabled_succeeds() + throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("disallowed.name.com"), + SimpleRequest.getDefaultInstance()); + } + + @Test + public void perRpcAuthorityOverride_noTlsCredentialsUsedToBuildChannel_fails() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { 
+ ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + SSLSocketFactory sslSocketFactory = TestUtils.newSslSocketFactoryForCa( + Platform.get().getProvider(), TestUtils.loadCert("ca.pem")); + ManagedChannel channel = grpcCleanupRule.register( + OkHttpChannelBuilder.forAddress("localhost", server.getPort()) + .overrideAuthority(TestUtils.TEST_SERVER_HOST) + .directExecutor() + .sslSocketFactory(sslSocketFactory) + .build()); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("bar.test.google.fr"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for authority verification failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getStatus().getDescription()).isEqualTo( + "Could not verify authority 'bar.test.google.fr' for the rpc with no " + + "X509TrustManager available"); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_trustManager_permitted_succeeds() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509ExtendedTrustManager regularTrustManager = + (X509ExtendedTrustManager) 
getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new HostnameCheckingX509ExtendedTrustManager(regularTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("good.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_trustManager_denied_fails() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509ExtendedTrustManager regularTrustManager = + (X509ExtendedTrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new HostnameCheckingX509ExtendedTrustManager(regularTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("bad.test.google.fr"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for authority verification failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + 
assertThat(ex.getCause().getCause()).isInstanceOf(CertificateException.class); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_trustManager_denied_flagDisabled_succeeds() + throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509ExtendedTrustManager regularTrustManager = + (X509ExtendedTrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new HostnameCheckingX509ExtendedTrustManager(regularTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("bad.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } + + /** + * This test simulates the absence of X509ExtendedTrustManager while still using the + * real trust manager for the connection handshake to happen. When the TrustManager is not an + * X509ExtendedTrustManager, the per-rpc check ignores the trust manager. However, the + * HostnameVerifier is still used, so only valid authorities are permitted. 
+ */ + @Test + public void perRpcAuthorityOverride_notX509ExtendedTrustManager_goodAuthority_succeeds() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("foo.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_notX509ExtendedTrustManager_badAuthority_fails() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new 
FakeTrustManager(x509ExtendedTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("disallowed.name.com"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for authority verification failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getStatus().getDescription()) + .isEqualTo("HostNameVerifier verification failed for authority 'disallowed.name.com'"); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void + perRpcAuthorityOverride_notX509ExtendedTrustManager_badAuthority_flagDisabled_succeeds() + throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("disallowed.name.com"), + SimpleRequest.getDefaultInstance()); + } + @Test public void mtls_succeeds() throws Exception { ServerCredentials serverCreds; @@ -282,6 
+620,127 @@ public void hostnameVerifierFails_fails() assertThat(status.getCause()).isInstanceOf(SSLPeerUnverifiedException.class); } + /** Used to simulate the case of X509ExtendedTrustManager not present. */ + private static class FakeTrustManager implements X509TrustManager { + + private final X509TrustManager delegate; + + public FakeTrustManager(X509TrustManager x509ExtendedTrustManager) { + this.delegate = x509ExtendedTrustManager; + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkClientTrusted(x509Certificates, s); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkServerTrusted(x509Certificates, s); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return delegate.getAcceptedIssuers(); + } + } + + /** + * Checks against a limited set of hostnames. In production, EndpointIdentificationAlgorithm is + * unset so the default trust manager will not fail based on the hostname. This class is used to + * test user-provided trust managers that may have their own behavior. 
+ */ + private static class HostnameCheckingX509ExtendedTrustManager + extends ForwardingX509ExtendedTrustManager { + public HostnameCheckingX509ExtendedTrustManager(X509ExtendedTrustManager tm) { + super(tm); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + String peer = ((SSLSocket) socket).getHandshakeSession().getPeerHost(); + if (!"foo.test.google.fr".equals(peer) && !"good.test.google.fr".equals(peer)) { + throw new CertificateException("Peer verification failed."); + } + super.checkServerTrusted(chain, authType, socket); + } + } + + @IgnoreJRERequirement + private static class ForwardingX509ExtendedTrustManager extends X509ExtendedTrustManager { + private final X509ExtendedTrustManager delegate; + + private ForwardingX509ExtendedTrustManager(X509ExtendedTrustManager delegate) { + this.delegate = delegate; + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + delegate.checkServerTrusted(chain, authType, socket); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + delegate.checkServerTrusted(chain, authType, engine); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + delegate.checkServerTrusted(chain, authType); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + delegate.checkClientTrusted(chain, authType, engine); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + delegate.checkClientTrusted(chain, authType); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + 
delegate.checkClientTrusted(chain, authType, socket); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return delegate.getAcceptedIssuers(); + } + } + + private static Optional getX509ExtendedTrustManager(InputStream rootCerts) + throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts); + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + return Arrays.stream(trustManagerFactory.getTrustManagers()) + .filter(trustManager -> trustManager instanceof X509ExtendedTrustManager).findFirst(); + } + private static Server server(ServerCredentials creds) throws IOException { return OkHttpServerBuilder.forPort(0, creds) .directExecutor() diff --git a/okhttp/src/test/java/io/grpc/okhttp/UtilsTest.java b/okhttp/src/test/java/io/grpc/okhttp/UtilsTest.java index 895ba7ff7c7..1c97e027b4a 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/UtilsTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/UtilsTest.java @@ -16,7 +16,9 @@ package io.grpc.okhttp; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import io.grpc.InternalChannelz.SocketOptions; @@ -25,9 +27,8 @@ import io.grpc.okhttp.internal.TlsVersion; import java.net.Socket; import java.util.List; -import org.junit.Rule; +import java.util.Locale; import org.junit.Test; -import org.junit.rules.ExpectedException; import 
org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,16 +38,12 @@ @RunWith(JUnit4.class) public class UtilsTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void convertSpecRejectsPlaintext() { com.squareup.okhttp.ConnectionSpec plaintext = com.squareup.okhttp.ConnectionSpec.CLEARTEXT; - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("plaintext ConnectionSpec is not accepted"); - Utils.convertSpec(plaintext); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Utils.convertSpec(plaintext)); + assertThat(e).hasMessageThat().isEqualTo("plaintext ConnectionSpec is not accepted"); } @Test @@ -95,6 +92,9 @@ public void getSocketOptions() throws Exception { assertEquals("5000", socketOptions.others.get("SO_SNDBUF")); assertEquals("true", socketOptions.others.get("SO_KEEPALIVE")); assertEquals("true", socketOptions.others.get("SO_OOBINLINE")); - assertEquals("8", socketOptions.others.get("IP_TOS")); + String osName = System.getProperty("os.name").toLowerCase(Locale.ENGLISH); + if (!osName.startsWith("windows")) { + assertEquals("8", socketOptions.others.get("IP_TOS")); + } } } diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/OkHostnameVerifier.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/OkHostnameVerifier.java index 34bb56ee2d6..f6efb2d90e7 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/OkHostnameVerifier.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/OkHostnameVerifier.java @@ -29,10 +29,13 @@ import java.util.List; import java.util.Locale; import java.util.regex.Pattern; +import java.nio.charset.StandardCharsets; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; import javax.security.auth.x500.X500Principal; +import 
com.google.common.base.Utf8; +import com.google.common.base.Ascii; /** * A HostnameVerifier consistent with altNames = getSubjectAltNames(certificate, ALT_DNS_NAME); for (int i = 0, size = altNames.size(); i < size; i++) { @@ -198,7 +204,7 @@ private boolean verifyHostName(String hostName, String pattern) { } // hostName and pattern are now absolute domain names. - pattern = pattern.toLowerCase(Locale.US); + pattern = Ascii.toLowerCase(pattern); // hostName and pattern are now in lower case -- domain names are case-insensitive. if (!pattern.contains("*")) { @@ -254,4 +260,13 @@ private boolean verifyHostName(String hostName, String pattern) { // hostName matches pattern return true; } + + /** + * Returns true if {@code input} is an ASCII string. + * @param input the string to check. + */ + private static boolean isAscii(String input) { + // Only ASCII characters are 1 byte in UTF-8. + return Utf8.encodedLength(input) == input.length(); + } } diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Platform.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Platform.java index 6ed3bc50b81..29ea8055b26 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Platform.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Platform.java @@ -283,7 +283,7 @@ private static boolean isAtLeastAndroid41() { /** * Select the first recognized security provider according to the preference order returned by - * {@link Security#getProviders}. If a recognized provider is not found then warn but continue. + * {@link Security#getProviders}. 
*/ private static Provider getAndroidSecurityProvider() { Provider[] providers = Security.getProviders(); @@ -295,7 +295,6 @@ private static Provider getAndroidSecurityProvider() { } } } - logger.log(Level.WARNING, "Unable to find Conscrypt"); return null; } diff --git a/opentelemetry/build.gradle b/opentelemetry/build.gradle index 509960e5dbc..594686294f0 100644 --- a/opentelemetry/build.gradle +++ b/opentelemetry/build.gradle @@ -12,17 +12,29 @@ dependencies { implementation libraries.guava, project(':grpc-core'), libraries.opentelemetry.api, - libraries.auto.value.annotations - - testImplementation testFixtures(project(':grpc-core')), - project(':grpc-testing'), + libraries.auto.value.annotations, + libraries.animalsniffer.annotations + + testImplementation project(':grpc-testing'), + project(':grpc-testing-proto'), + project(':grpc-inprocess'), + testFixtures(project(':grpc-core')), + testFixtures(project(':grpc-api')), libraries.opentelemetry.sdk.testing, libraries.assertj.core // opentelemetry.sdk.testing uses compileOnly for assertj annotationProcessor libraries.auto.value - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("jar").configure { diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java index 03183ef4920..87ad61c9f27 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java @@ -18,8 +18,11 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.internal.GrpcUtil.IMPLEMENTATION_VERSION; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.HEDGE_BUCKETS; import static 
io.grpc.opentelemetry.internal.OpenTelemetryConstants.LATENCY_BUCKETS; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.RETRY_BUCKETS; import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.SIZE_BUCKETS; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.TRANSPARENT_RETRY_BUCKETS; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; @@ -33,16 +36,21 @@ import io.grpc.ManagedChannelBuilder; import io.grpc.MetricSink; import io.grpc.ServerBuilder; +import io.grpc.internal.GrpcUtil; import io.grpc.opentelemetry.internal.OpenTelemetryConstants; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.metrics.Meter; import io.opentelemetry.api.metrics.MeterProvider; +import io.opentelemetry.api.trace.Tracer; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Predicate; +import javax.annotation.Nullable; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** * The entrypoint for OpenTelemetry metrics functionality in gRPC. 
@@ -61,6 +69,10 @@ public Stopwatch get() { } }; + @VisibleForTesting + static boolean ENABLE_OTEL_TRACING = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_ENABLE_OTEL_TRACING", false); + private final OpenTelemetry openTelemetrySdk; private final MeterProvider meterProvider; private final Meter meter; @@ -68,6 +80,7 @@ public Stopwatch get() { private final boolean disableDefault; private final OpenTelemetryMetricsResource resource; private final OpenTelemetryMetricsModule openTelemetryMetricsModule; + private final OpenTelemetryTracingModule openTelemetryTracingModule; private final List optionalLabels; private final MetricSink sink; @@ -87,7 +100,9 @@ private GrpcOpenTelemetry(Builder builder) { this.resource = createMetricInstruments(meter, enableMetrics, disableDefault); this.optionalLabels = ImmutableList.copyOf(builder.optionalLabels); this.openTelemetryMetricsModule = new OpenTelemetryMetricsModule( - STOPWATCH_SUPPLIER, resource, optionalLabels, builder.plugins); + STOPWATCH_SUPPLIER, resource, optionalLabels, builder.plugins, + builder.targetFilter); + this.openTelemetryTracingModule = new OpenTelemetryTracingModule(openTelemetrySdk); this.sink = new OpenTelemetryMetricSink(meter, enableMetrics, disableDefault, optionalLabels); } @@ -125,6 +140,16 @@ MetricSink getSink() { return sink; } + @VisibleForTesting + Tracer getTracer() { + return this.openTelemetryTracingModule.getTracer(); + } + + @VisibleForTesting + TargetFilter getTargetAttributeFilter() { + return this.openTelemetryMetricsModule.getTargetAttributeFilter(); + } + /** * Registers GrpcOpenTelemetry globally, applying its configuration to all subsequently created * gRPC channels and servers. 
@@ -152,6 +177,9 @@ public void configureChannelBuilder(ManagedChannelBuilder builder) { InternalManagedChannelBuilder.addMetricSink(builder, sink); InternalManagedChannelBuilder.interceptWithTarget( builder, openTelemetryMetricsModule::getClientInterceptor); + if (ENABLE_OTEL_TRACING) { + builder.intercept(openTelemetryTracingModule.getClientInterceptor()); + } } /** @@ -160,7 +188,15 @@ public void configureChannelBuilder(ManagedChannelBuilder builder) { * @param serverBuilder the server builder to configure */ public void configureServerBuilder(ServerBuilder serverBuilder) { + /* To ensure baggage propagation to metrics, we need the tracing + tracers to be initialised before metrics */ + if (ENABLE_OTEL_TRACING) { + serverBuilder.addStreamTracerFactory( + openTelemetryTracingModule.getServerTracerFactory()); + serverBuilder.intercept(openTelemetryTracingModule.getServerSpanPropagationInterceptor()); + } serverBuilder.addStreamTracerFactory(openTelemetryMetricsModule.getServerTracerFactory()); + serverBuilder.addMetricSink(sink); } @VisibleForTesting @@ -220,6 +256,54 @@ static OpenTelemetryMetricsResource createMetricInstruments(Meter meter, .build()); } + if (isMetricEnabled("grpc.client.call.retries", enableMetrics, disableDefault)) { + builder.clientCallRetriesCounter( + meter.histogramBuilder( + "grpc.client.call.retries") + .setUnit("{retry}") + .setDescription("Number of retries during the client call. " + + "If there were no retries, 0 is not reported.") + .ofLongs() + .setExplicitBucketBoundariesAdvice(RETRY_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.client.call.transparent_retries", enableMetrics, + disableDefault)) { + builder.clientCallTransparentRetriesCounter( + meter.histogramBuilder( + "grpc.client.call.transparent_retries") + .setUnit("{transparent_retry}") + .setDescription("Number of transparent retries during the client call. 
" + + "If there were no transparent retries, 0 is not reported.") + .ofLongs() + .setExplicitBucketBoundariesAdvice(TRANSPARENT_RETRY_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.client.call.hedges", enableMetrics, disableDefault)) { + builder.clientCallHedgesCounter( + meter.histogramBuilder( + "grpc.client.call.hedges") + .setUnit("{hedge}") + .setDescription("Number of hedges during the client call. " + + "If there were no hedges, 0 is not reported.") + .ofLongs() + .setExplicitBucketBoundariesAdvice(HEDGE_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.client.call.retry_delay", enableMetrics, disableDefault)) { + builder.clientCallRetryDelayCounter( + meter.histogramBuilder( + "grpc.client.call.retry_delay") + .setUnit("s") + .setDescription("Total time of delay while there is no active attempt during the " + + "client call") + .setExplicitBucketBoundariesAdvice(LATENCY_BUCKETS) + .build()); + } + if (isMetricEnabled("grpc.server.call.started", enableMetrics, disableDefault)) { builder.serverCallCountCounter( meter.counterBuilder("grpc.server.call.started") @@ -238,8 +322,8 @@ static OpenTelemetryMetricsResource createMetricInstruments(Meter meter, .build()); } - if (isMetricEnabled("grpc.server.call.sent_total_compressed_message_size", enableMetrics, - disableDefault)) { + if (isMetricEnabled("grpc.server.call.sent_total_compressed_message_size", + enableMetrics, disableDefault)) { builder.serverTotalSentCompressedMessageSizeCounter( meter.histogramBuilder( "grpc.server.call.sent_total_compressed_message_size") @@ -250,8 +334,8 @@ static OpenTelemetryMetricsResource createMetricInstruments(Meter meter, .build()); } - if (isMetricEnabled("grpc.server.call.rcvd_total_compressed_message_size", enableMetrics, - disableDefault)) { + if (isMetricEnabled("grpc.server.call.rcvd_total_compressed_message_size", + enableMetrics, disableDefault)) { builder.serverTotalReceivedCompressedMessageSizeCounter( meter.histogramBuilder( 
"grpc.server.call.rcvd_total_compressed_message_size") @@ -275,6 +359,13 @@ static boolean isMetricEnabled(String metricName, Map enableMet && !disableDefault; } + /** + * Internal interface to avoid storing a {@link java.util.function.Predicate} directly, ensuring + * compatibility with Android devices (API level < 24) that do not use library desugaring. + */ + interface TargetFilter { + boolean test(String target); + } /** * Builder for configuring {@link GrpcOpenTelemetry}. @@ -285,6 +376,8 @@ public static class Builder { private final Collection optionalLabels = new ArrayList<>(); private final Map enableMetrics = new HashMap<>(); private boolean disableAll; + @Nullable + private TargetFilter targetFilter; private Builder() {} @@ -342,6 +435,31 @@ public Builder disableAllMetrics() { return this; } + Builder enableTracing(boolean enable) { + ENABLE_OTEL_TRACING = enable; + return this; + } + + /** + * Sets an optional filter to control recording of the {@code grpc.target} metric + * attribute. + * + *

If the predicate returns {@code true}, the original target is recorded. Otherwise, + * the target is recorded as {@code "other"} to limit metric cardinality. + * + *

If unset, all targets are recorded as-is. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12595") + @IgnoreJRERequirement + public Builder targetAttributeFilter(@Nullable Predicate filter) { + if (filter == null) { + this.targetFilter = null; + } else { + this.targetFilter = filter::test; + } + return this; + } + /** * Returns a new {@link GrpcOpenTelemetry} built with the configuration of this {@link * Builder}. diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java index 5d5543dddda..ea1e7ab803f 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java @@ -29,4 +29,8 @@ public static void builderPlugin( GrpcOpenTelemetry.Builder builder, InternalOpenTelemetryPlugin plugin) { builder.plugin(plugin); } + + public static void enableTracing(GrpcOpenTelemetry.Builder builder, boolean enable) { + builder.enableTracing(enable); + } } diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricSink.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricSink.java index 8f612804436..fd8af7f998f 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricSink.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricSink.java @@ -27,6 +27,7 @@ import io.grpc.LongCounterMetricInstrument; import io.grpc.LongGaugeMetricInstrument; import io.grpc.LongHistogramMetricInstrument; +import io.grpc.LongUpDownCounterMetricInstrument; import io.grpc.MetricInstrument; import io.grpc.MetricSink; import io.opentelemetry.api.common.Attributes; @@ -36,6 +37,7 @@ import io.opentelemetry.api.metrics.DoubleHistogram; import io.opentelemetry.api.metrics.LongCounter; import io.opentelemetry.api.metrics.LongHistogram; +import 
io.opentelemetry.api.metrics.LongUpDownCounter; import io.opentelemetry.api.metrics.Meter; import io.opentelemetry.api.metrics.ObservableLongMeasurement; import io.opentelemetry.api.metrics.ObservableMeasurement; @@ -117,6 +119,22 @@ public void addLongCounter(LongCounterMetricInstrument metricInstrument, long va counter.add(value, attributes); } + @Override + public void addLongUpDownCounter(LongUpDownCounterMetricInstrument metricInstrument, long value, + List requiredLabelValues, + List optionalLabelValues) { + MeasuresData instrumentData = measures.get(metricInstrument.getIndex()); + if (instrumentData == null) { + // Disabled metric + return; + } + Attributes attributes = createAttributes(metricInstrument.getRequiredLabelKeys(), + metricInstrument.getOptionalLabelKeys(), requiredLabelValues, optionalLabelValues, + instrumentData.getOptionalLabelsBitSet()); + LongUpDownCounter counter = (LongUpDownCounter) instrumentData.getMeasure(); + counter.add(value, attributes); + } + @Override public void recordDoubleHistogram(DoubleHistogramMetricInstrument metricInstrument, double value, List requiredLabelValues, List optionalLabelValues) { @@ -256,6 +274,11 @@ public void updateMeasures(List instruments) { .setDescription(description) .ofLongs() .buildObserver(); + } else if (instrument instanceof LongUpDownCounterMetricInstrument) { + openTelemetryMeasure = openTelemetryMeter.upDownCounterBuilder(name) + .setUnit(unit) + .setDescription(description) + .build(); } else { logger.log(Level.FINE, "Unsupported metric instrument type : {0}", instrument); openTelemetryMeasure = null; diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java index f631da59d01..c9e623b4415 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java @@ 
-17,6 +17,9 @@ package io.grpc.opentelemetry; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.BACKEND_SERVICE_KEY; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.BAGGAGE_KEY; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.CUSTOM_LABEL_KEY; import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.LOCALITY_KEY; import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.METHOD_KEY; import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.STATUS_KEY; @@ -27,6 +30,7 @@ import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -36,13 +40,17 @@ import io.grpc.Deadline; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.Grpc; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StreamTracer; +import io.grpc.opentelemetry.GrpcOpenTelemetry.TargetFilter; +import io.opentelemetry.api.baggage.Baggage; import io.opentelemetry.api.common.AttributesBuilder; +import io.opentelemetry.context.Context; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -54,7 +62,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Provides factories for {@link StreamTracer} that records metrics to OpenTelemetry. @@ -64,13 +71,16 @@ * tracer. 
It's the tracer that reports per-attempt stats, and the factory that reports the stats * of the overall RPC, such as RETRIES_PER_CALL, to OpenTelemetry. * + *

This module optionally applies a target attribute filter to limit the cardinality of + * the {@code grpc.target} attribute in client-side metrics by mapping disallowed targets + * to a stable placeholder value. + * *

On the server-side, there is only one ServerStream per each ServerCall, and ServerStream * starts earlier than the ServerCall. Therefore, only one tracer is created per stream/call, and * it's the tracer that reports the summary to OpenTelemetry. */ final class OpenTelemetryMetricsModule { private static final Logger logger = Logger.getLogger(OpenTelemetryMetricsModule.class.getName()); - private static final String LOCALITY_LABEL_NAME = "grpc.lb.locality"; public static final ImmutableSet DEFAULT_PER_CALL_METRICS_SET = ImmutableSet.of( "grpc.client.attempt.started", @@ -90,15 +100,34 @@ final class OpenTelemetryMetricsModule { private final OpenTelemetryMetricsResource resource; private final Supplier stopwatchSupplier; private final boolean localityEnabled; + private final boolean backendServiceEnabled; + private final boolean customLabelEnabled; private final ImmutableList plugins; + @Nullable + private final TargetFilter targetAttributeFilter; + + OpenTelemetryMetricsModule(Supplier stopwatchSupplier, + OpenTelemetryMetricsResource resource, + Collection optionalLabels, List plugins) { + this(stopwatchSupplier, resource, optionalLabels, plugins, null); + } OpenTelemetryMetricsModule(Supplier stopwatchSupplier, - OpenTelemetryMetricsResource resource, Collection optionalLabels, - List plugins) { + OpenTelemetryMetricsResource resource, + Collection optionalLabels, List plugins, + @Nullable TargetFilter targetAttributeFilter) { this.resource = checkNotNull(resource, "resource"); this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier"); - this.localityEnabled = optionalLabels.contains(LOCALITY_LABEL_NAME); + this.localityEnabled = optionalLabels.contains(LOCALITY_KEY.getKey()); + this.backendServiceEnabled = optionalLabels.contains(BACKEND_SERVICE_KEY.getKey()); + this.customLabelEnabled = optionalLabels.contains(CUSTOM_LABEL_KEY.getKey()); this.plugins = ImmutableList.copyOf(plugins); + this.targetAttributeFilter = targetAttributeFilter; + } 
+ + @VisibleForTesting + TargetFilter getTargetAttributeFilter() { + return targetAttributeFilter; } /** @@ -119,7 +148,15 @@ ClientInterceptor getClientInterceptor(String target) { pluginBuilder.add(plugin); } } - return new MetricsClientInterceptor(target, pluginBuilder.build()); + String filteredTarget = recordTarget(target); + return new MetricsClientInterceptor(filteredTarget, pluginBuilder.build()); + } + + String recordTarget(String target) { + if (targetAttributeFilter == null || target == null) { + return target; + } + return targetAttributeFilter.test(target) ? target : "other"; } static String recordMethodName(String fullMethodName, boolean isGeneratedMethod) { @@ -162,6 +199,7 @@ private static final class ClientTracer extends ClientStreamTracer { volatile long outboundWireSize; volatile long inboundWireSize; volatile String locality; + volatile String backendService; long attemptNanos; Code statusCode; @@ -206,9 +244,12 @@ public void inboundWireSize(long bytes) { @Override public void addOptionalLabel(String key, String value) { - if (LOCALITY_LABEL_NAME.equals(key)) { + if ("grpc.lb.locality".equals(key)) { locality = value; } + if ("grpc.lb.backend_service".equals(key)) { + backendService = value; + } } @Override @@ -232,7 +273,7 @@ public void streamClosed(Status status) { statusCode = Code.DEADLINE_EXCEEDED; } } - attemptsState.attemptEnded(); + attemptsState.attemptEnded(info.getCallOptions()); recordFinishedAttempt(); } @@ -248,6 +289,17 @@ void recordFinishedAttempt() { } builder.put(LOCALITY_KEY, savedLocality); } + if (module.backendServiceEnabled) { + String savedBackendService = backendService; + if (savedBackendService == null) { + savedBackendService = ""; + } + builder.put(BACKEND_SERVICE_KEY, savedBackendService); + } + if (module.customLabelEnabled) { + builder.put( + CUSTOM_LABEL_KEY, info.getCallOptions().getOption(Grpc.CALL_OPTION_CUSTOM_LABEL)); + } for (OpenTelemetryPlugin.ClientStreamPlugin plugin : streamPlugins) { 
plugin.addLabels(builder); } @@ -255,15 +307,15 @@ void recordFinishedAttempt() { if (module.resource.clientAttemptDurationCounter() != null ) { module.resource.clientAttemptDurationCounter() - .record(attemptNanos * SECONDS_PER_NANO, attribute); + .record(attemptNanos * SECONDS_PER_NANO, attribute, attemptsState.otelContext); } if (module.resource.clientTotalSentCompressedMessageSizeCounter() != null) { module.resource.clientTotalSentCompressedMessageSizeCounter() - .record(outboundWireSize, attribute); + .record(outboundWireSize, attribute, attemptsState.otelContext); } if (module.resource.clientTotalReceivedCompressedMessageSizeCounter() != null) { module.resource.clientTotalReceivedCompressedMessageSizeCounter() - .record(inboundWireSize, attribute); + .record(inboundWireSize, attribute, attemptsState.otelContext); } } } @@ -272,16 +324,20 @@ void recordFinishedAttempt() { static final class CallAttemptsTracerFactory extends ClientStreamTracer.Factory { private final OpenTelemetryMetricsModule module; private final String target; - private final Stopwatch attemptStopwatch; + private final Stopwatch attemptDelayStopwatch; private final Stopwatch callStopWatch; @GuardedBy("lock") private boolean callEnded; private final String fullMethodName; private final List callPlugins; + private final Context otelContext; private Status status; + private long retryDelayNanos; private long callLatencyNanos; private final Object lock = new Object(); private final AtomicLong attemptsPerCall = new AtomicLong(); + private final AtomicLong hedgedAttemptsPerCall = new AtomicLong(); + private final AtomicLong transparentRetriesPerCall = new AtomicLong(); @GuardedBy("lock") private int activeStreams; @GuardedBy("lock") @@ -290,22 +346,29 @@ static final class CallAttemptsTracerFactory extends ClientStreamTracer.Factory CallAttemptsTracerFactory( OpenTelemetryMetricsModule module, String target, + CallOptions callOptions, String fullMethodName, - List callPlugins) { + List 
callPlugins, Context otelContext) { this.module = checkNotNull(module, "module"); this.target = checkNotNull(target, "target"); this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName"); this.callPlugins = checkNotNull(callPlugins, "callPlugins"); - this.attemptStopwatch = module.stopwatchSupplier.get(); + this.otelContext = checkNotNull(otelContext, "otelContext"); + this.attemptDelayStopwatch = module.stopwatchSupplier.get(); this.callStopWatch = module.stopwatchSupplier.get().start(); - io.opentelemetry.api.common.Attributes attribute = io.opentelemetry.api.common.Attributes.of( - METHOD_KEY, fullMethodName, - TARGET_KEY, target); + AttributesBuilder builder = io.opentelemetry.api.common.Attributes.builder() + .put(METHOD_KEY, fullMethodName) + .put(TARGET_KEY, target); + if (module.customLabelEnabled) { + builder.put( + CUSTOM_LABEL_KEY, callOptions.getOption(Grpc.CALL_OPTION_CUSTOM_LABEL)); + } + io.opentelemetry.api.common.Attributes attribute = builder.build(); // Record here in case mewClientStreamTracer() would never be called. if (module.resource.clientAttemptCountCounter() != null) { - module.resource.clientAttemptCountCounter().add(1, attribute); + module.resource.clientAttemptCountCounter().add(1, attribute, otelContext); } } @@ -316,22 +379,32 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata metada // This can be the case when the call is cancelled but a retry attempt is created. return new ClientStreamTracer() {}; } - if (++activeStreams == 1 && attemptStopwatch.isRunning()) { - attemptStopwatch.stop(); + if (++activeStreams == 1 && attemptDelayStopwatch.isRunning()) { + attemptDelayStopwatch.stop(); + retryDelayNanos = attemptDelayStopwatch.elapsed(TimeUnit.NANOSECONDS); } } // Skip recording for the first time, since it is already recorded in // CallAttemptsTracerFactory constructor. attemptsPerCall will be non-zero after the first // attempt, as first attempt cannot be a transparent retry. 
if (attemptsPerCall.get() > 0) { - io.opentelemetry.api.common.Attributes attribute = - io.opentelemetry.api.common.Attributes.of(METHOD_KEY, fullMethodName, - TARGET_KEY, target); + AttributesBuilder builder = io.opentelemetry.api.common.Attributes.builder() + .put(METHOD_KEY, fullMethodName) + .put(TARGET_KEY, target); + if (module.customLabelEnabled) { + builder.put( + CUSTOM_LABEL_KEY, info.getCallOptions().getOption(Grpc.CALL_OPTION_CUSTOM_LABEL)); + } + io.opentelemetry.api.common.Attributes attribute = builder.build(); if (module.resource.clientAttemptCountCounter() != null) { - module.resource.clientAttemptCountCounter().add(1, attribute); + module.resource.clientAttemptCountCounter().add(1, attribute, otelContext); } } - if (!info.isTransparentRetry()) { + if (info.isTransparentRetry()) { + transparentRetriesPerCall.incrementAndGet(); + } else if (info.isHedging()) { + hedgedAttemptsPerCall.incrementAndGet(); + } else { attemptsPerCall.incrementAndGet(); } return newClientTracer(info); @@ -350,11 +423,11 @@ private ClientTracer newClientTracer(StreamInfo info) { } // Called whenever each attempt is ended. 
- void attemptEnded() { + void attemptEnded(CallOptions callOptions) { boolean shouldRecordFinishedCall = false; synchronized (lock) { if (--activeStreams == 0) { - attemptStopwatch.start(); + attemptDelayStopwatch.start(); if (callEnded && !finishedCallToBeRecorded) { shouldRecordFinishedCall = true; finishedCallToBeRecorded = true; @@ -362,11 +435,11 @@ void attemptEnded() { } } if (shouldRecordFinishedCall) { - recordFinishedCall(); + recordFinishedCall(callOptions); } } - void callEnded(Status status) { + void callEnded(Status status, CallOptions callOptions) { callStopWatch.stop(); this.status = status; boolean shouldRecordFinishedCall = false; @@ -382,26 +455,73 @@ void callEnded(Status status) { } } if (shouldRecordFinishedCall) { - recordFinishedCall(); + recordFinishedCall(callOptions); } } - void recordFinishedCall() { + void recordFinishedCall(CallOptions callOptions) { if (attemptsPerCall.get() == 0) { ClientTracer tracer = newClientTracer(null); - tracer.attemptNanos = attemptStopwatch.elapsed(TimeUnit.NANOSECONDS); + tracer.attemptNanos = attemptDelayStopwatch.elapsed(TimeUnit.NANOSECONDS); tracer.statusCode = status.getCode(); tracer.recordFinishedAttempt(); } callLatencyNanos = callStopWatch.elapsed(TimeUnit.NANOSECONDS); - io.opentelemetry.api.common.Attributes attribute = - io.opentelemetry.api.common.Attributes.of(METHOD_KEY, fullMethodName, - TARGET_KEY, target, - STATUS_KEY, status.getCode().toString()); + // Base attributes + AttributesBuilder builder = io.opentelemetry.api.common.Attributes.builder() + .put(METHOD_KEY, fullMethodName) + .put(TARGET_KEY, target); + if (module.customLabelEnabled) { + builder.put(CUSTOM_LABEL_KEY, callOptions.getOption(Grpc.CALL_OPTION_CUSTOM_LABEL)); + } + io.opentelemetry.api.common.Attributes baseAttributes = builder.build(); + + // Duration if (module.resource.clientCallDurationCounter() != null) { - module.resource.clientCallDurationCounter() - .record(callLatencyNanos * SECONDS_PER_NANO, attribute); + 
module.resource.clientCallDurationCounter().record( + callLatencyNanos * SECONDS_PER_NANO, + baseAttributes.toBuilder() + .put(STATUS_KEY, status.getCode().toString()) + .build(), + otelContext + ); + } + + // Retry counts + if (module.resource.clientCallRetriesCounter() != null) { + long retriesPerCall = Math.max(attemptsPerCall.get() - 1, 0); + if (retriesPerCall > 0) { + module.resource.clientCallRetriesCounter() + .record(retriesPerCall, baseAttributes, otelContext); + } + } + + // Hedge counts + if (module.resource.clientCallHedgesCounter() != null) { + long hedges = hedgedAttemptsPerCall.get(); + if (hedges > 0) { + module.resource.clientCallHedgesCounter() + .record(hedges, baseAttributes, otelContext); + } + } + + // Transparent Retry counts + if (module.resource.clientCallTransparentRetriesCounter() != null) { + long transparentRetries = transparentRetriesPerCall.get(); + if (transparentRetries > 0) { + module.resource.clientCallTransparentRetriesCounter() + .record(transparentRetries, baseAttributes, otelContext); + } + } + + // Retry delay + if (module.resource.clientCallRetryDelayCounter() != null) { + module.resource.clientCallRetryDelayCounter().record( + retryDelayNanos * SECONDS_PER_NANO, + baseAttributes, + otelContext + ); } } } @@ -441,6 +561,7 @@ private static final class ServerTracer extends ServerStreamTracer { private final OpenTelemetryMetricsModule module; private final String fullMethodName; private final List streamPlugins; + private Context otelContext = Context.root(); private volatile boolean isGeneratedMethod; private volatile int streamClosed; private final Stopwatch stopwatch; @@ -455,6 +576,17 @@ private static final class ServerTracer extends ServerStreamTracer { this.stopwatch = module.stopwatchSupplier.get().start(); } + @Override + public io.grpc.Context filterContext(io.grpc.Context context) { + Baggage baggage = BAGGAGE_KEY.get(context); + if (baggage != null) { + otelContext = Context.current().with(baggage); + } else { + 
otelContext = Context.current(); + } + return context; + } + @Override public void serverCallStarted(ServerCallInfo callInfo) { // Only record method name as an attribute if isSampledToLocalTracing is set to true, @@ -462,12 +594,13 @@ public void serverCallStarted(ServerCallInfo callInfo) { // created methods result in high cardinality metrics. boolean isSampledToLocalTracing = callInfo.getMethodDescriptor().isSampledToLocalTracing(); isGeneratedMethod = isSampledToLocalTracing; + io.opentelemetry.api.common.Attributes attribute = io.opentelemetry.api.common.Attributes.of( METHOD_KEY, recordMethodName(fullMethodName, isSampledToLocalTracing)); if (module.resource.serverCallCountCounter() != null) { - module.resource.serverCallCountCounter().add(1, attribute); + module.resource.serverCallCountCounter().add(1, attribute, otelContext); } } @@ -521,15 +654,15 @@ public void streamClosed(Status status) { if (module.resource.serverCallDurationCounter() != null) { module.resource.serverCallDurationCounter() - .record(elapsedTimeNanos * SECONDS_PER_NANO, attributes); + .record(elapsedTimeNanos * SECONDS_PER_NANO, attributes, otelContext); } if (module.resource.serverTotalSentCompressedMessageSizeCounter() != null) { module.resource.serverTotalSentCompressedMessageSizeCounter() - .record(outboundWireSize, attributes); + .record(outboundWireSize, attributes, otelContext); } if (module.resource.serverTotalReceivedCompressedMessageSizeCounter() != null) { module.resource.serverTotalReceivedCompressedMessageSizeCounter() - .record(inboundWireSize, attributes); + .record(inboundWireSize, attributes, otelContext); } } } @@ -549,7 +682,8 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata } streamPlugins = Collections.unmodifiableList(streamPluginsMutable); } - return new ServerTracer(OpenTelemetryMetricsModule.this, fullMethodName, streamPlugins); + return new ServerTracer(OpenTelemetryMetricsModule.this, fullMethodName, + streamPlugins); } } @@ 
-580,13 +714,14 @@ public ClientCall interceptCall( callOptions = plugin.filterCallOptions(callOptions); } } + final CallOptions finalCallOptions = callOptions; // Only record method name as an attribute if isSampledToLocalTracing is set to true, // which is true for all generated methods. Otherwise, programatically // created methods result in high cardinality metrics. final CallAttemptsTracerFactory tracerFactory = new CallAttemptsTracerFactory( - OpenTelemetryMetricsModule.this, target, + OpenTelemetryMetricsModule.this, target, callOptions, recordMethodName(method.getFullMethodName(), method.isSampledToLocalTracing()), - callPlugins); + callPlugins, Context.current()); ClientCall call = next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); return new SimpleForwardingClientCall(call) { @@ -599,7 +734,7 @@ public void start(Listener responseListener, Metadata headers) { new SimpleForwardingClientCallListener(responseListener) { @Override public void onClose(Status status, Metadata trailers) { - tracerFactory.callEnded(status); + tracerFactory.callEnded(status, finalCallOptions); super.onClose(status, trailers); } }, @@ -609,3 +744,4 @@ public void onClose(Status status, Metadata trailers) { } } } + diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsResource.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsResource.java index e519b7e1eb6..d32ae1e67f5 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsResource.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsResource.java @@ -41,6 +41,17 @@ abstract class OpenTelemetryMetricsResource { @Nullable abstract LongHistogram clientTotalReceivedCompressedMessageSizeCounter(); + @Nullable + abstract LongHistogram clientCallRetriesCounter(); + + @Nullable + abstract LongHistogram clientCallTransparentRetriesCounter(); + + @Nullable + abstract LongHistogram clientCallHedgesCounter(); 
+ + @Nullable + abstract DoubleHistogram clientCallRetryDelayCounter(); /* Server Metrics */ @Nullable @@ -73,6 +84,14 @@ abstract static class Builder { abstract Builder clientTotalReceivedCompressedMessageSizeCounter( LongHistogram counter); + abstract Builder clientCallRetriesCounter(LongHistogram counter); + + abstract Builder clientCallTransparentRetriesCounter(LongHistogram counter); + + abstract Builder clientCallHedgesCounter(LongHistogram counter); + + abstract Builder clientCallRetryDelayCounter(DoubleHistogram counter); + abstract Builder serverCallCountCounter(LongCounter counter); abstract Builder serverCallDurationCounter(DoubleHistogram counter); diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java index 11659c87708..d214e99bd75 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java @@ -18,6 +18,8 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ClientStreamTracer.NAME_RESOLUTION_DELAYED; +import static io.grpc.internal.GrpcUtil.IMPLEMENTATION_VERSION; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.BAGGAGE_KEY; import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; @@ -28,15 +30,23 @@ import io.grpc.ClientStreamTracer; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.ForwardingServerCallListener; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.ServerStreamTracer; +import io.grpc.internal.GrpcUtil; +import io.grpc.opentelemetry.internal.OpenTelemetryConstants; import 
io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.baggage.Baggage; import io.opentelemetry.api.common.AttributesBuilder; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; import io.opentelemetry.context.propagation.ContextPropagators; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.logging.Level; @@ -50,7 +60,8 @@ final class OpenTelemetryTracingModule { private static final Logger logger = Logger.getLogger(OpenTelemetryTracingModule.class.getName()); @VisibleForTesting - static final String OTEL_TRACING_SCOPE_NAME = "grpc-java"; + final io.grpc.Context.Key otelSpan = io.grpc.Context.key("opentelemetry-span-key"); + @Nullable private static final AtomicIntegerFieldUpdater callEndedUpdater; @Nullable @@ -83,13 +94,23 @@ final class OpenTelemetryTracingModule { private final MetadataGetter metadataGetter = MetadataGetter.getInstance(); private final MetadataSetter metadataSetter = MetadataSetter.getInstance(); private final TracingClientInterceptor clientInterceptor = new TracingClientInterceptor(); + private final ServerInterceptor serverSpanPropagationInterceptor = + new TracingServerSpanPropagationInterceptor(); private final ServerTracerFactory serverTracerFactory = new ServerTracerFactory(); OpenTelemetryTracingModule(OpenTelemetry openTelemetry) { - this.otelTracer = checkNotNull(openTelemetry.getTracer(OTEL_TRACING_SCOPE_NAME), "otelTracer"); + this.otelTracer = checkNotNull(openTelemetry.getTracerProvider(), "tracerProvider") + .tracerBuilder(OpenTelemetryConstants.INSTRUMENTATION_SCOPE) + .setInstrumentationVersion(IMPLEMENTATION_VERSION) + .build(); this.contextPropagators = checkNotNull(openTelemetry.getPropagators(), "contextPropagators"); } + @VisibleForTesting + Tracer getTracer() { + return otelTracer; + } + /** * Creates a {@link 
CallAttemptsTracerFactory} for a new call. */ @@ -112,6 +133,10 @@ ClientInterceptor getClientInterceptor() { return clientInterceptor; } + ServerInterceptor getServerSpanPropagationInterceptor() { + return serverSpanPropagationInterceptor; + } + @VisibleForTesting final class CallAttemptsTracerFactory extends ClientStreamTracer.Factory { volatile int callEnded; @@ -196,7 +221,6 @@ public void outboundMessageSent( @Override public void inboundMessageRead( int seqNo, long optionalWireSize, long optionalUncompressedSize) { - //TODO(yifeizhuang): needs support from message deframer. if (optionalWireSize != optionalUncompressedSize) { recordInboundCompressedMessage(span, seqNo, optionalWireSize); } @@ -222,13 +246,15 @@ private final class ServerTracer extends ServerStreamTracer { private final Span span; volatile int streamClosed; private int seqNo; + private Baggage baggage; - ServerTracer(String fullMethodName, @Nullable Span remoteSpan) { + ServerTracer(String fullMethodName, @Nullable Span remoteSpan, Baggage baggage) { checkNotNull(fullMethodName, "fullMethodName"); this.span = otelTracer.spanBuilder(generateTraceSpanName(true, fullMethodName)) .setParent(remoteSpan == null ? 
null : Context.current().with(remoteSpan)) .startSpan(); + this.baggage = baggage; } /** @@ -252,6 +278,13 @@ public void streamClosed(io.grpc.Status status) { endSpanWithStatus(span, status); } + @Override + public io.grpc.Context filterContext(io.grpc.Context context) { + return context + .withValue(otelSpan, span) + .withValue(BAGGAGE_KEY, baggage); + } + @Override public void outboundMessageSent( int seqNo, long optionalWireSize, long optionalUncompressedSize) { @@ -289,7 +322,79 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata if (remoteSpan == Span.getInvalid()) { remoteSpan = null; } - return new ServerTracer(fullMethodName, remoteSpan); + Baggage baggage = Baggage.fromContext(context); + return new ServerTracer(fullMethodName, remoteSpan, baggage); + } + } + + @VisibleForTesting + final class TracingServerSpanPropagationInterceptor implements ServerInterceptor { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + Span span = otelSpan.get(io.grpc.Context.current()); + if (span == null) { + logger.log(Level.FINE, "Server span not found. 
ServerTracerFactory for server " + + "tracing must be set."); + return next.startCall(call, headers); + } + Context serverCallContext = Context.current(); + serverCallContext = serverCallContext.with(span); + Baggage baggage = BAGGAGE_KEY.get(); + if (baggage != null) { + serverCallContext = serverCallContext.with(baggage); + } else { + logger.log(Level.WARNING, "Server baggage not found which is unexpected, " + + "as it is being added unconditionally in filterContext()."); + } + try (Scope scope = serverCallContext.makeCurrent()) { + return new ContextServerCallListener<>(next.startCall(call, headers), serverCallContext); + } + } + } + + private static class ContextServerCallListener extends + ForwardingServerCallListener.SimpleForwardingServerCallListener { + private final Context context; + + protected ContextServerCallListener(ServerCall.Listener delegate, Context context) { + super(delegate); + this.context = checkNotNull(context, "context"); + } + + @Override + public void onMessage(ReqT message) { + try (Scope scope = context.makeCurrent()) { + delegate().onMessage(message); + } + } + + @Override + public void onHalfClose() { + try (Scope scope = context.makeCurrent()) { + delegate().onHalfClose(); + } + } + + @Override + public void onCancel() { + try (Scope scope = context.makeCurrent()) { + delegate().onCancel(); + } + } + + @Override + public void onComplete() { + try (Scope scope = context.makeCurrent()) { + delegate().onComplete(); + } + } + + @Override + public void onReady() { + try (Scope scope = context.makeCurrent()) { + delegate().onReady(); + } } } @@ -358,7 +463,7 @@ private void recordOutboundMessageSentEvent(Span span, if (optionalWireSize != -1 && optionalWireSize != optionalUncompressedSize) { attributesBuilder.put("message-size-compressed", optionalWireSize); } - span.addEvent("Outbound message sent", attributesBuilder.build()); + span.addEvent("Outbound message", attributesBuilder.build()); } private void 
recordInboundCompressedMessage(Span span, int seqNo, long optionalWireSize) { @@ -372,22 +477,14 @@ private void recordInboundMessageSize(Span span, int seqNo, long bytes) { AttributesBuilder attributesBuilder = io.opentelemetry.api.common.Attributes.builder(); attributesBuilder.put("sequence-number", seqNo); attributesBuilder.put("message-size", bytes); - span.addEvent("Inbound message received", attributesBuilder.build()); - } - - private String generateErrorStatusDescription(io.grpc.Status status) { - if (status.getDescription() != null) { - return status.getCode() + ": " + status.getDescription(); - } else { - return status.getCode().toString(); - } + span.addEvent("Inbound message", attributesBuilder.build()); } private void endSpanWithStatus(Span span, io.grpc.Status status) { if (status.isOk()) { span.setStatus(StatusCode.OK); } else { - span.setStatus(StatusCode.ERROR, generateErrorStatusDescription(status)); + span.setStatus(StatusCode.ERROR, GrpcUtil.statusToPrettyString(status)); } span.end(); } diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/internal/OpenTelemetryConstants.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/internal/OpenTelemetryConstants.java index 081e376b8c5..c09a1a2beca 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/internal/OpenTelemetryConstants.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/internal/OpenTelemetryConstants.java @@ -16,7 +16,9 @@ package io.grpc.opentelemetry.internal; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import io.opentelemetry.api.baggage.Baggage; import io.opentelemetry.api.common.AttributeKey; import java.util.List; @@ -33,6 +35,22 @@ public final class OpenTelemetryConstants { public static final AttributeKey LOCALITY_KEY = AttributeKey.stringKey("grpc.lb.locality"); + public static final AttributeKey BACKEND_SERVICE_KEY = + AttributeKey.stringKey("grpc.lb.backend_service"); + + public static 
final AttributeKey CUSTOM_LABEL_KEY = + AttributeKey.stringKey("grpc.client.call.custom"); + + public static final AttributeKey DISCONNECT_ERROR_KEY = + AttributeKey.stringKey("grpc.disconnect_error"); + + public static final AttributeKey SECURITY_LEVEL_KEY = + AttributeKey.stringKey("grpc.security_level"); + + @VisibleForTesting + public static final io.grpc.Context.Key BAGGAGE_KEY = + io.grpc.Context.key("opentelemetry-baggage-key"); + public static final List LATENCY_BUCKETS = ImmutableList.of( 0d, 0.00001d, 0.00005d, 0.0001d, 0.0003d, 0.0006d, 0.0008d, 0.001d, 0.002d, @@ -46,6 +64,13 @@ public final class OpenTelemetryConstants { 0L, 1024L, 2048L, 4096L, 16384L, 65536L, 262144L, 1048576L, 4194304L, 16777216L, 67108864L, 268435456L, 1073741824L, 4294967296L); + public static final List RETRY_BUCKETS = ImmutableList.of(1L, 2L, 3L, 4L, 5L); + + public static final List TRANSPARENT_RETRY_BUCKETS = + ImmutableList.of(1L, 2L, 3L, 4L, 5L, 10L); + + public static final List HEDGE_BUCKETS = ImmutableList.of(1L, 2L, 3L, 4L, 5L); + private OpenTelemetryConstants() { } } diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcOpenTelemetryTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcOpenTelemetryTest.java index e4a0fa46e8b..f0bd6f93098 100644 --- a/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcOpenTelemetryTest.java +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcOpenTelemetryTest.java @@ -17,15 +17,26 @@ package io.grpc.opentelemetry; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import com.google.common.collect.ImmutableList; -import io.grpc.MetricSink; +import io.grpc.ClientInterceptor; +import io.grpc.ManagedChannelBuilder; +import io.grpc.ServerBuilder; import 
io.grpc.internal.GrpcUtil; +import io.grpc.opentelemetry.GrpcOpenTelemetry.TargetFilter; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.sdk.OpenTelemetrySdk; import io.opentelemetry.sdk.metrics.SdkMeterProvider; import io.opentelemetry.sdk.testing.exporter.InMemoryMetricReader; +import io.opentelemetry.sdk.trace.SdkTracerProvider; import java.util.Arrays; +import org.junit.After; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -35,7 +46,19 @@ public class GrpcOpenTelemetryTest { private final InMemoryMetricReader inMemoryMetricReader = InMemoryMetricReader.create(); private final SdkMeterProvider meterProvider = SdkMeterProvider.builder().registerMetricReader(inMemoryMetricReader).build(); + private final SdkTracerProvider tracerProvider = SdkTracerProvider.builder().build(); private final OpenTelemetry noopOpenTelemetry = OpenTelemetry.noop(); + private boolean originalEnableOtelTracing; + + @Before + public void setup() { + originalEnableOtelTracing = GrpcOpenTelemetry.ENABLE_OTEL_TRACING; + } + + @After + public void tearDown() { + GrpcOpenTelemetry.ENABLE_OTEL_TRACING = originalEnableOtelTracing; + } @Test public void build() { @@ -56,6 +79,32 @@ public void build() { assertThat(openTelemetryModule.getOptionalLabels()).isEqualTo(ImmutableList.of("version")); } + @Test + public void buildTracer() { + OpenTelemetrySdk sdk = + OpenTelemetrySdk.builder().setTracerProvider(tracerProvider).build(); + + GrpcOpenTelemetry grpcOpenTelemetry = GrpcOpenTelemetry.newBuilder() + .enableTracing(true) + .sdk(sdk).build(); + + assertThat(grpcOpenTelemetry.getOpenTelemetryInstance()).isSameInstanceAs(sdk); + assertThat(grpcOpenTelemetry.getTracer()).isSameInstanceAs( + tracerProvider.tracerBuilder("grpc-java") + .setInstrumentationVersion(GrpcUtil.IMPLEMENTATION_VERSION) + .build()); + ServerBuilder mockServerBuiler = mock(ServerBuilder.class); + 
grpcOpenTelemetry.configureServerBuilder(mockServerBuiler); + verify(mockServerBuiler, times(2)).addStreamTracerFactory(any()); + verify(mockServerBuiler).intercept(any()); + verify(mockServerBuiler).addMetricSink(any()); + verifyNoMoreInteractions(mockServerBuiler); + + ManagedChannelBuilder mockChannelBuilder = mock(ManagedChannelBuilder.class); + grpcOpenTelemetry.configureChannelBuilder(mockChannelBuilder); + verify(mockChannelBuilder).intercept(any(ClientInterceptor.class)); + } + @Test public void builderDefaults() { GrpcOpenTelemetry module = GrpcOpenTelemetry.newBuilder().build(); @@ -72,7 +121,25 @@ public void builderDefaults() { .build()); assertThat(module.getEnableMetrics()).isEmpty(); assertThat(module.getOptionalLabels()).isEmpty(); - assertThat(module.getSink()).isInstanceOf(MetricSink.class); + + assertThat(module.getTracer()).isSameInstanceAs(noopOpenTelemetry + .getTracerProvider() + .tracerBuilder("grpc-java") + .setInstrumentationVersion(GrpcUtil.IMPLEMENTATION_VERSION) + .build() + ); + } + + @Test + public void builderTargetAttributeFilter() { + GrpcOpenTelemetry module = GrpcOpenTelemetry.newBuilder() + .targetAttributeFilter(t -> t.contains("allowed.com")) + .build(); + + TargetFilter internalFilter = module.getTargetAttributeFilter(); + + assertThat(internalFilter.test("allowed.com")).isTrue(); + assertThat(internalFilter.test("example.com")).isFalse(); } @Test diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricSinkTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricSinkTest.java index c538da55dcb..cced4de3cb4 100644 --- a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricSinkTest.java +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricSinkTest.java @@ -24,6 +24,7 @@ import io.grpc.LongCounterMetricInstrument; import io.grpc.LongGaugeMetricInstrument; import io.grpc.LongHistogramMetricInstrument; +import 
io.grpc.LongUpDownCounterMetricInstrument; import io.grpc.MetricInstrument; import io.grpc.MetricSink; import io.grpc.opentelemetry.internal.OpenTelemetryConstants; @@ -144,16 +145,25 @@ public void addCounter_enabledMetric() { "Number of client calls started", "count", Collections.emptyList(), Collections.emptyList(), true); + LongUpDownCounterMetricInstrument longUpDownCounterInstrument = + new LongUpDownCounterMetricInstrument(2, "active_carrier_pigeons", + "Active Carrier Pigeons", "pigeons", + Collections.emptyList(), + Collections.emptyList(), true); + // Create sink sink = new OpenTelemetryMetricSink(testMeter, enabledMetrics, false, Collections.emptyList()); // Invoke updateMeasures - sink.updateMeasures(Arrays.asList(longCounterInstrument, doubleCounterInstrument)); + sink.updateMeasures(Arrays.asList(longCounterInstrument, doubleCounterInstrument, + longUpDownCounterInstrument)); sink.addLongCounter(longCounterInstrument, 123L, Collections.emptyList(), Collections.emptyList()); sink.addDoubleCounter(doubleCounterInstrument, 12.0, Collections.emptyList(), Collections.emptyList()); + sink.addLongUpDownCounter(longUpDownCounterInstrument, -3L, Collections.emptyList(), + Collections.emptyList()); assertThat(openTelemetryTesting.getMetrics()) .satisfiesExactlyInAnyOrder( @@ -184,7 +194,21 @@ public void addCounter_enabledMetric() { .hasPointsSatisfying( point -> point - .hasValue(12.0D)))); + .hasValue(12.0D))), + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName("active_carrier_pigeons") + .hasDescription("Active Carrier Pigeons") + .hasUnit("pigeons") + .hasLongSumSatisfying( + longSum -> + longSum + .hasPointsSatisfying( + point -> + point + .hasValue(-3L)))); } @Test @@ -192,18 +216,27 @@ public void addCounter_disabledMetric() { // set up sink with disabled metric Map enabledMetrics = new HashMap<>(); enabledMetrics.put("client_latency", false); + 
enabledMetrics.put("active_carrier_pigeons", false); LongCounterMetricInstrument instrument = new LongCounterMetricInstrument(0, "client_latency", "Client latency", "s", Collections.emptyList(), Collections.emptyList(), true); + LongUpDownCounterMetricInstrument longUpDownCounterInstrument = + new LongUpDownCounterMetricInstrument(1, "active_carrier_pigeons", + "Active Carrier Pigeons", "pigeons", + Collections.emptyList(), + Collections.emptyList(), false); + // Create sink sink = new OpenTelemetryMetricSink(testMeter, enabledMetrics, true, Collections.emptyList()); // Invoke updateMeasures - sink.updateMeasures(Arrays.asList(instrument)); + sink.updateMeasures(Arrays.asList(instrument, longUpDownCounterInstrument)); sink.addLongCounter(instrument, 123L, Collections.emptyList(), Collections.emptyList()); + sink.addLongUpDownCounter(longUpDownCounterInstrument, -13L, Collections.emptyList(), + Collections.emptyList()); assertThat(openTelemetryTesting.getMetrics()).isEmpty(); } @@ -377,6 +410,7 @@ public void registerBatchCallback_bothEnabledAndDisabled() { public void recordLabels() { Map enabledMetrics = new HashMap<>(); enabledMetrics.put("client_latency", true); + enabledMetrics.put("ghosts_in_the_wire", true); List optionalLabels = Arrays.asList("optional_label_key_2"); @@ -384,16 +418,24 @@ public void recordLabels() { new LongCounterMetricInstrument(0, "client_latency", "Client latency", "s", ImmutableList.of("required_label_key_1", "required_label_key_2"), ImmutableList.of("optional_label_key_1", "optional_label_key_2"), false); + LongUpDownCounterMetricInstrument longUpDownCounterInstrument = + new LongUpDownCounterMetricInstrument(1, "ghosts_in_the_wire", + "Number of Ghosts Haunting the Wire", "{ghosts}", + ImmutableList.of("required_label_key_1", "required_label_key_2"), + ImmutableList.of("optional_label_key_1", "optional_label_key_2"), false); // Create sink sink = new OpenTelemetryMetricSink(testMeter, enabledMetrics, false, optionalLabels); // Invoke 
updateMeasures - sink.updateMeasures(Arrays.asList(longCounterInstrument)); + sink.updateMeasures(Arrays.asList(longCounterInstrument, longUpDownCounterInstrument)); sink.addLongCounter(longCounterInstrument, 123L, ImmutableList.of("required_label_value_1", "required_label_value_2"), ImmutableList.of("optional_label_value_1", "optional_label_value_2")); + sink.addLongUpDownCounter(longUpDownCounterInstrument, -400L, + ImmutableList.of("required_label_value_1", "required_label_value_2"), + ImmutableList.of("optional_label_value_1", "optional_label_value_2")); io.opentelemetry.api.common.Attributes expectedAtrributes = io.opentelemetry.api.common.Attributes.of( @@ -417,6 +459,22 @@ public void recordLabels() { point -> point .hasAttributes(expectedAtrributes) - .hasValue(123L)))); + .hasValue(123L))), + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName("ghosts_in_the_wire") + .hasDescription("Number of Ghosts Haunting the Wire") + .hasUnit("{ghosts}") + .hasLongSumSatisfying( + longSum -> + longSum + .hasPointsSatisfying( + point -> + point + .hasAttributes(expectedAtrributes) + .hasValue(-400L)))); + } } diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java index 6128323154d..14139b8e439 100644 --- a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java @@ -25,8 +25,11 @@ import static java.util.Collections.emptyList; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; 
+import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import com.google.common.collect.ImmutableMap; @@ -37,28 +40,52 @@ import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; import io.grpc.ClientStreamTracer; +import io.grpc.Grpc; +import io.grpc.KnownLength; +import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.Server; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerServiceDefinition; import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer.ServerCallInfo; +import io.grpc.ServiceDescriptor; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.FakeClock; +import io.grpc.opentelemetry.GrpcOpenTelemetry.TargetFilter; import io.grpc.opentelemetry.OpenTelemetryMetricsModule.CallAttemptsTracerFactory; import io.grpc.opentelemetry.internal.OpenTelemetryConstants; +import io.grpc.stub.ClientCalls; +import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.GrpcServerRule; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.baggage.Baggage; +import io.opentelemetry.api.baggage.propagation.W3CBaggagePropagator; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.metrics.DoubleHistogram; import io.opentelemetry.api.metrics.Meter; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import io.opentelemetry.context.propagation.ContextPropagators; +import io.opentelemetry.sdk.OpenTelemetrySdk; import io.opentelemetry.sdk.common.InstrumentationScopeInfo; +import io.opentelemetry.sdk.metrics.data.MetricData; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import java.io.IOException; import java.io.InputStream; import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.Optional; import 
java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; +import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -79,10 +106,9 @@ public class OpenTelemetryMetricsModuleTest { private static final CallOptions.Key CUSTOM_OPTION = CallOptions.Key.createWithDefault("option1", "default"); private static final CallOptions CALL_OPTIONS = - CallOptions.DEFAULT.withOption(CUSTOM_OPTION, "customvalue"); + CallOptions.DEFAULT.withOption(NAME_RESOLUTION_DELAYED, 10L); private static final ClientStreamTracer.StreamInfo STREAM_INFO = - ClientStreamTracer.StreamInfo.newBuilder() - .setCallOptions(CallOptions.DEFAULT.withOption(NAME_RESOLUTION_DELAYED, 10L)).build(); + ClientStreamTracer.StreamInfo.newBuilder().setCallOptions(CALL_OPTIONS).build(); private static final String CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME = "grpc.client.attempt.started"; private static final String CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME = "grpc.client.attempt.duration"; @@ -91,6 +117,11 @@ public class OpenTelemetryMetricsModuleTest { private static final String CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE = "grpc.client.attempt.rcvd_total_compressed_message_size"; private static final String CLIENT_CALL_DURATION = "grpc.client.call.duration"; + private static final String CLIENT_CALL_RETRIES = "grpc.client.call.retries"; + private static final String CLIENT_CALL_TRANSPARENT_RETRIES = + "grpc.client.call.transparent_retries"; + private static final String CLIENT_CALL_HEDGES = "grpc.client.call.hedges"; + private static final String CLIENT_CALL_RETRY_DELAY = "grpc.client.call.retry_delay"; private static final String SERVER_CALL_COUNT = "grpc.server.call.started"; private static final String SERVER_CALL_DURATION = "grpc.server.call.duration"; private static final String SERVER_CALL_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE @@ -107,7 +138,7 @@ public class OpenTelemetryMetricsModuleTest { { 0L, 1024L, 2048L, 
4096L, 16384L, 65536L, 262144L, 1048576L, 4194304L, 16777216L, 67108864L, 268435456L, 1073741824L, 4294967296L }; - private static final class StringInputStream extends InputStream { + private static final class StringInputStream extends InputStream implements KnownLength { final String string; StringInputStream(String string) { @@ -118,6 +149,11 @@ private static final class StringInputStream extends InputStream { public int read() { throw new UnsupportedOperationException("should not be called"); } + + @Override + public int available() throws IOException { + return string == null ? 0 : string.length(); + } } private static final MethodDescriptor.Marshaller MARSHALLER = @@ -136,6 +172,8 @@ public String parse(InputStream stream) { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Rule + public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + @Rule public final GrpcServerRule grpcServerRule = new GrpcServerRule().directExecutor(); @Rule public final OpenTelemetryRule openTelemetryTesting = OpenTelemetryRule.create(); @@ -146,6 +184,9 @@ public String parse(InputStream stream) { @Captor private ArgumentCaptor statusCaptor; + private Server server; + private ManagedChannel channel; + private final FakeClock fakeClock = new FakeClock(); private final MethodDescriptor method = MethodDescriptor.newBuilder() @@ -164,6 +205,17 @@ public String parse(InputStream stream) { public void setUp() throws Exception { testMeter = openTelemetryTesting.getOpenTelemetry() .getMeter(OpenTelemetryConstants.INSTRUMENTATION_SCOPE); + + } + + @After + public void tearDown() { + if (channel != null) { + channel.shutdownNow(); + } + if (server != null) { + server.shutdownNow(); + } } @Test @@ -186,7 +238,7 @@ public ServerCall.Listener startCall( }).build()); final AtomicReference capturedCallOptions = new AtomicReference<>(); - ClientInterceptor callOptionsCatureInterceptor = new ClientInterceptor() { + ClientInterceptor callOptionsCaptureInterceptor = new 
ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { @@ -196,10 +248,11 @@ public ClientCall interceptCall( }; Channel interceptedChannel = ClientInterceptors.intercept( - grpcServerRule.getChannel(), callOptionsCatureInterceptor, + grpcServerRule.getChannel(), callOptionsCaptureInterceptor, module.getClientInterceptor("target:///")); ClientCall call; - call = interceptedChannel.newCall(method, CALL_OPTIONS); + call = interceptedChannel.newCall( + method, CallOptions.DEFAULT.withOption(CUSTOM_OPTION, "customvalue")); assertEquals("customvalue", capturedCallOptions.get().getOption(CUSTOM_OPTION)); assertEquals(1, capturedCallOptions.get().getStreamTracerFactories().size()); @@ -228,7 +281,8 @@ public void clientBasicMetrics() { enabledMetricsMap, disableDefaultMetrics); OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory(module, target, method.getFullMethodName(), emptyList()); + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); Metadata headers = new Metadata(); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); @@ -273,7 +327,7 @@ public void clientBasicMetrics() { tracer.inboundMessage(1); tracer.inboundWireSize(154); tracer.streamClosed(Status.OK); - callAttemptsTracerFactory.callEnded(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); io.opentelemetry.api.common.Attributes clientAttributes = io.opentelemetry.api.common.Attributes.of( @@ -370,6 +424,89 @@ public void clientBasicMetrics() { .hasBucketCounts(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)))); + + assertThat(openTelemetryTesting.getMetrics()) + .extracting("name") + 
.doesNotContain( + CLIENT_CALL_RETRIES, + CLIENT_CALL_TRANSPARENT_RETRIES, + CLIENT_CALL_HEDGES, + CLIENT_CALL_RETRY_DELAY); + } + + @Test + public void clientBasicMetrics_withRetryMetricsEnabled_shouldRecordZeroOrBeAbsent() { + // Explicitly enable the retry metrics + Map enabledMetrics = ImmutableMap.of( + CLIENT_CALL_RETRIES, true, + CLIENT_CALL_TRANSPARENT_RETRIES, true, + CLIENT_CALL_HEDGES, true, + CLIENT_CALL_RETRY_DELAY, true + ); + + String target = "target:///"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetrics, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + + fakeClock.forwardTime(30, TimeUnit.MILLISECONDS); + tracer.outboundHeaders(); + fakeClock.forwardTime(100, TimeUnit.MILLISECONDS); + tracer.outboundMessage(0); + tracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); + + io.opentelemetry.api.common.Attributes finalAttributes + = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE), + metric -> assertThat(metric).hasName(CLIENT_CALL_DURATION), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_RETRY_DELAY) + 
.hasHistogramSatisfying( + histogram -> + histogram.hasPointsSatisfying( + point -> + point + .hasSum(0) + .hasCount(1) + .hasAttributes(finalAttributes))) + + ); + + List optionalMetricNames = Arrays.asList( + CLIENT_CALL_RETRIES, + CLIENT_CALL_TRANSPARENT_RETRIES, + CLIENT_CALL_HEDGES); + + for (String metricName : optionalMetricNames) { + Optional metric = openTelemetryTesting.getMetrics().stream() + .filter(m -> m.getName().equals(metricName)) + .findFirst(); + if (metric.isPresent()) { + assertThat(metric.get()) + .hasHistogramSatisfying( + histogram -> + histogram.hasPointsSatisfying( + point -> + point + .hasSum(0) + .hasCount(1) + .hasAttributes(finalAttributes))); + } + } } // This test is only unit-testing the metrics recording logic. The retry behavior is faked. @@ -380,8 +517,8 @@ public void recordAttemptMetrics() { enabledMetricsMap, disableDefaultMetrics); OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, - method.getFullMethodName(), emptyList()); + new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, CALL_OPTIONS, + method.getFullMethodName(), emptyList(), Context.root()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); @@ -696,7 +833,7 @@ public void recordAttemptMetrics() { fakeClock.forwardTime(24, MILLISECONDS); // RPC succeeded tracer.streamClosed(Status.OK); - callAttemptsTracerFactory.callEnded(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); io.opentelemetry.api.common.Attributes clientAttributes2 = io.opentelemetry.api.common.Attributes.of( @@ -823,6 +960,182 @@ public void recordAttemptMetrics() { .hasBucketBoundaries(sizeBuckets)))); } + @Test + public void recordAttemptMetrics_withRetryMetricsEnabled() { + Map enabledMetrics = ImmutableMap.of( + 
CLIENT_CALL_RETRIES, true, + CLIENT_CALL_TRANSPARENT_RETRIES, true, + CLIENT_CALL_HEDGES, true, + CLIENT_CALL_RETRY_DELAY, true + ); + + String target = "dns:///example.com"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetrics, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, CALL_OPTIONS, + method.getFullMethodName(), emptyList(), Context.root()); + + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + fakeClock.forwardTime(154, TimeUnit.MILLISECONDS); + tracer.streamClosed(Status.UNAVAILABLE); + + fakeClock.forwardTime(1000, TimeUnit.MILLISECONDS); + tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + fakeClock.forwardTime(100, TimeUnit.MILLISECONDS); + tracer.streamClosed(Status.NOT_FOUND); + + fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); + tracer = callAttemptsTracerFactory.newClientStreamTracer( + STREAM_INFO.toBuilder().setIsTransparentRetry(true).build(), new Metadata()); + fakeClock.forwardTime(32, MILLISECONDS); + tracer.streamClosed(Status.UNAVAILABLE); + + fakeClock.forwardTime(10, MILLISECONDS); + tracer = callAttemptsTracerFactory.newClientStreamTracer( + STREAM_INFO.toBuilder().setIsTransparentRetry(true).build(), new Metadata()); + tracer.inboundWireSize(33); + fakeClock.forwardTime(24, MILLISECONDS); + tracer.streamClosed(Status.OK); // RPC succeeded + + // --- The overall call ends --- + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); + + // Define attributes for assertions + io.opentelemetry.api.common.Attributes finalAttributes + = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + // FINAL ASSERTION BLOCK + 
assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + // Default metrics + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE), + metric -> assertThat(metric).hasName(CLIENT_CALL_DURATION), + + // --- Assertions for the retry metrics --- + metric -> assertThat(metric) + .hasName(CLIENT_CALL_RETRIES) + .hasUnit("{retry}") + .hasHistogramSatisfying(histogram -> histogram.hasPointsSatisfying( + point -> point + .hasCount(1) + .hasSum(1) // We faked one standard retry + .hasAttributes(finalAttributes))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_TRANSPARENT_RETRIES) + .hasUnit("{transparent_retry}") + .hasHistogramSatisfying(histogram -> histogram.hasPointsSatisfying( + point -> point + .hasCount(1) + .hasSum(2) // We faked two transparent retries + .hasAttributes(finalAttributes))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_RETRY_DELAY) + .hasUnit("s") + .hasHistogramSatisfying(histogram -> histogram.hasPointsSatisfying( + point -> point + .hasCount(1) + .hasSum(1.02) // 1000ms + 10ms + 10ms + .hasAttributes(finalAttributes))) + ); + } + + @Test + public void recordAttemptMetrics_withHedgedCalls() { + // Enable the retry metrics, including hedges + Map enabledMetrics = ImmutableMap.of( + CLIENT_CALL_RETRIES, true, + CLIENT_CALL_TRANSPARENT_RETRIES, true, + CLIENT_CALL_HEDGES, true, + CLIENT_CALL_RETRY_DELAY, true + ); + + String target = "dns:///example.com"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetrics, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory 
callAttemptsTracerFactory = + new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, CALL_OPTIONS, + method.getFullMethodName(), emptyList(), Context.root()); + + // Create a StreamInfo specifically for hedged attempts + final ClientStreamTracer.StreamInfo hedgedStreamInfo = + STREAM_INFO.toBuilder().setIsHedging(true).build(); + + // --- First attempt starts --- + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + + // --- Faking a hedged attempt --- + fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); // Hedging delay + ClientStreamTracer hedgeTracer1 = + callAttemptsTracerFactory.newClientStreamTracer(hedgedStreamInfo, new Metadata()); + + // --- Faking a second hedged attempt --- + fakeClock.forwardTime(20, TimeUnit.MILLISECONDS); // Another hedging delay + ClientStreamTracer hedgeTracer2 = + callAttemptsTracerFactory.newClientStreamTracer(hedgedStreamInfo, new Metadata()); + + // --- Let the attempts resolve --- + fakeClock.forwardTime(50, TimeUnit.MILLISECONDS); + // Initial attempt is cancelled because a hedge will succeed + tracer.streamClosed(Status.CANCELLED); + hedgeTracer1.streamClosed(Status.UNAVAILABLE); // First hedge fails + + fakeClock.forwardTime(30, TimeUnit.MILLISECONDS); + hedgeTracer2.streamClosed(Status.OK); // Second hedge succeeds + + // --- The overall call ends --- + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); + + // Define attributes for assertions + io.opentelemetry.api.common.Attributes finalAttributes + = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + // FINAL ASSERTION BLOCK + // We expect 7 metrics: 5 default + hedges + retry_delay. + // Retries and transparent_retries are 0 and will not be reported. 
+ assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + // Default metrics + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE), + metric -> assertThat(metric).hasName(CLIENT_CALL_DURATION), + + // --- Assertions for the NEW metrics --- + metric -> assertThat(metric) + .hasName(CLIENT_CALL_HEDGES) + .hasUnit("{hedge}") + .hasHistogramSatisfying(histogram -> histogram.hasPointsSatisfying( + point -> point + .hasCount(1) + .hasSum(2) + .hasAttributes(finalAttributes))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_RETRY_DELAY) + .hasUnit("s") + .hasHistogramSatisfying( + histogram -> + histogram.hasPointsSatisfying( + point -> + point + .hasCount(1) + .hasSum(0) + .hasAttributes(finalAttributes))) + ); + } + @Test public void clientStreamNeverCreatedStillRecordMetrics() { String target = "dns:///foo.example.com"; @@ -830,11 +1143,11 @@ public void clientStreamNeverCreatedStillRecordMetrics() { enabledMetricsMap, disableDefaultMetrics); OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, - method.getFullMethodName(), emptyList()); + new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, CALL_OPTIONS, + method.getFullMethodName(), emptyList(), Context.root()); fakeClock.forwardTime(3000, MILLISECONDS); Status status = Status.DEADLINE_EXCEEDED.withDescription("5 seconds"); - callAttemptsTracerFactory.callEnded(status); + callAttemptsTracerFactory.callEnded(status, CALL_OPTIONS); io.opentelemetry.api.common.Attributes attemptStartedAttributes = 
io.opentelemetry.api.common.Attributes.of( @@ -937,9 +1250,11 @@ public void clientLocalityMetrics_present() { OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, enabledMetricsMap, disableDefaultMetrics); OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( - fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality"), emptyList()); + fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality"), + emptyList()); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory(module, target, method.getFullMethodName(), emptyList()); + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); @@ -948,7 +1263,7 @@ public void clientLocalityMetrics_present() { tracer.addOptionalLabel("grpc.lb.locality", "the-moon"); tracer.addOptionalLabel("grpc.lb.foo", "thats-no-moon"); tracer.streamClosed(Status.OK); - callAttemptsTracerFactory.callEnded(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( TARGET_KEY, target, @@ -1005,14 +1320,16 @@ public void clientLocalityMetrics_missing() { OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, enabledMetricsMap, disableDefaultMetrics); OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( - fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality"), emptyList()); + fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality"), + emptyList()); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory(module, target, method.getFullMethodName(), emptyList()); + new 
CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); tracer.streamClosed(Status.OK); - callAttemptsTracerFactory.callEnded(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( TARGET_KEY, target, @@ -1063,6 +1380,236 @@ public void clientLocalityMetrics_missing() { point -> point.hasAttributes(clientAttributes)))); } + @Test + public void clientBackendServiceMetrics_present() { + String target = "target:///"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( + fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.backend_service"), + emptyList()); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); + + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer.addOptionalLabel("grpc.lb.foo", "unimportant"); + tracer.addOptionalLabel("grpc.lb.backend_service", "should-be-overwritten"); + tracer.addOptionalLabel("grpc.lb.backend_service", "the-moon"); + tracer.addOptionalLabel("grpc.lb.foo", "thats-no-moon"); + tracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); + + io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + io.opentelemetry.api.common.Attributes clientAttributes + = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, 
method.getFullMethodName(), + STATUS_KEY, Status.Code.OK.toString()); + + io.opentelemetry.api.common.Attributes clientAttributesWithBackendService + = clientAttributes.toBuilder() + .put(AttributeKey.stringKey("grpc.lb.backend_service"), "the-moon") + .build(); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttributes(attributes))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithBackendService))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithBackendService))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithBackendService))), + metric -> + assertThat(metric) + .hasName(CLIENT_CALL_DURATION) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributes)))); + } + + @Test + public void clientBackendServiceMetrics_missing() { + String target = "target:///"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( + fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.backend_service"), + emptyList()); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, 
method.getFullMethodName(), + emptyList(), Context.root()); + + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); + + io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + io.opentelemetry.api.common.Attributes clientAttributes + = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName(), + STATUS_KEY, Status.Code.OK.toString()); + + io.opentelemetry.api.common.Attributes clientAttributesWithBackendService + = clientAttributes.toBuilder() + .put(AttributeKey.stringKey("grpc.lb.backend_service"), "") + .build(); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttributes(attributes))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithBackendService))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithBackendService))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithBackendService))), + metric -> + assertThat(metric) + .hasName(CLIENT_CALL_DURATION) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> 
point.hasAttributes(clientAttributes)))); + } + + @Test + public void customLabel_present() { + Map enabledMetrics = ImmutableMap.of( + CLIENT_CALL_HEDGES, true, + CLIENT_CALL_RETRIES, true, + CLIENT_CALL_RETRY_DELAY, true, + CLIENT_CALL_TRANSPARENT_RETRIES, true + ); + String target = "target:///"; + String customValue = "some-random-value"; + CallOptions callOptions = + STREAM_INFO.getCallOptions().withOption(Grpc.CALL_OPTION_CUSTOM_LABEL, customValue); + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetrics, disableDefaultMetrics); + String customLabel = "grpc.client.call.custom"; + OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( + fakeClock.getStopwatchSupplier(), resource, Arrays.asList(customLabel), + emptyList()); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CallAttemptsTracerFactory( + module, target, callOptions, method.getFullMethodName(), emptyList(), Context.root()); + + ClientStreamTracer.StreamInfo streamInfo = + STREAM_INFO.toBuilder().setCallOptions(callOptions).build(); + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(streamInfo, new Metadata()); + tracer.streamClosed(Status.UNAVAILABLE); + + tracer = callAttemptsTracerFactory.newClientStreamTracer(streamInfo, new Metadata()); + tracer.streamClosed(Status.UNAVAILABLE); + + tracer = callAttemptsTracerFactory.newClientStreamTracer( + streamInfo.toBuilder().setIsTransparentRetry(true).build(), new Metadata()); + tracer.streamClosed(Status.UNAVAILABLE); + + tracer = callAttemptsTracerFactory.newClientStreamTracer( + streamInfo.toBuilder().setIsHedging(true).build(), new Metadata()); + tracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, callOptions); + + AttributeKey attributeKey = AttributeKey.stringKey(customLabel); + + assertThat(sortByName(openTelemetryTesting.getMetrics())) + .satisfiesExactly( + metric -> 
assertThat(metric) + .hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue), + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue), + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue), + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_DURATION) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_HEDGES) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_RETRIES) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_RETRY_DELAY) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_TRANSPARENT_RETRIES) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> 
point.hasAttribute(attributeKey, customValue)))); + } + @Test public void serverBasicMetrics() { OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, @@ -1187,12 +1734,151 @@ public void serverBasicMetrics() { } + + @Test + public void targetAttributeFilter_notSet_usesOriginalTarget() { + // Test that when no filter is set, the original target is used + String target = "dns:///example.com"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); + + Channel interceptedChannel = + ClientInterceptors.intercept( + grpcServerRule.getChannel(), module.getClientInterceptor(target)); + + ClientCall call = interceptedChannel.newCall(method, CALL_OPTIONS); + + // Make the call + Metadata headers = new Metadata(); + call.start(mockClientCallListener, headers); + + // End the call + call.halfClose(); + call.request(1); + + io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + assertThat(openTelemetryTesting.getMetrics()) + .anySatisfy( + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasUnit("{attempt}") + .hasLongSumSatisfying( + longSum -> + longSum + .hasPointsSatisfying( + point -> + point + .hasAttributes(attributes)))); + } + + @Test + public void targetAttributeFilter_allowsTarget_usesOriginalTarget() { + // Test that when filter allows the target, the original target is used + String target = "dns:///example.com"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = 
newOpenTelemetryMetricsModule(resource, + t -> t.contains("example.com")); + + Channel interceptedChannel = + ClientInterceptors.intercept( + grpcServerRule.getChannel(), module.getClientInterceptor(target)); + + ClientCall call = interceptedChannel.newCall(method, CALL_OPTIONS); + + // Make the call + Metadata headers = new Metadata(); + call.start(mockClientCallListener, headers); + + // End the call + call.halfClose(); + call.request(1); + + io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + assertThat(openTelemetryTesting.getMetrics()) + .anySatisfy( + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasUnit("{attempt}") + .hasLongSumSatisfying( + longSum -> + longSum + .hasPointsSatisfying( + point -> + point + .hasAttributes(attributes)))); + } + + @Test + public void targetAttributeFilter_rejectsTarget_mapsToOther() { + // Test that when filter rejects the target, it is mapped to "other" + String target = "dns:///example.com"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource, + t -> t.contains("allowed.com")); + + Channel interceptedChannel = + ClientInterceptors.intercept( + grpcServerRule.getChannel(), module.getClientInterceptor(target)); + + ClientCall call = interceptedChannel.newCall(method, CALL_OPTIONS); + + // Make the call + Metadata headers = new Metadata(); + call.start(mockClientCallListener, headers); + + // End the call + call.halfClose(); + call.request(1); + + io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, "other", + METHOD_KEY, method.getFullMethodName()); + + 
assertThat(openTelemetryTesting.getMetrics()) + .anySatisfy( + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasUnit("{attempt}") + .hasLongSumSatisfying( + longSum -> + longSum + .hasPointsSatisfying( + point -> + point + .hasAttributes(attributes)))); + } + private OpenTelemetryMetricsModule newOpenTelemetryMetricsModule( OpenTelemetryMetricsResource resource) { return new OpenTelemetryMetricsModule( fakeClock.getStopwatchSupplier(), resource, emptyList(), emptyList()); } + private OpenTelemetryMetricsModule newOpenTelemetryMetricsModule( + OpenTelemetryMetricsResource resource, TargetFilter filter) { + return new OpenTelemetryMetricsModule( + fakeClock.getStopwatchSupplier(), resource, emptyList(), emptyList(), + filter); + } + static class CallInfo extends ServerCallInfo { private final MethodDescriptor methodDescriptor; private final Attributes attributes; @@ -1223,4 +1909,130 @@ public String getAuthority() { return authority; } } + + @Test + public void serverMetrics_recordsBaggage() { + DoubleHistogram mockDurationHistogram = mock(DoubleHistogram.class); + OpenTelemetryMetricsResource mockResource = OpenTelemetryMetricsResource.builder() + .serverCallDurationCounter(mockDurationHistogram) + .build(); + + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(mockResource); + ServerStreamTracer.Factory tracerFactory = module.getServerTracerFactory(); + + Baggage baggage = Baggage.builder() + .put("baggage-key-1", "baggage-val-1") + .build(); + + io.grpc.Context grpcContext = io.grpc.Context.ROOT + .withValue(OpenTelemetryConstants.BAGGAGE_KEY, baggage); + io.grpc.Context previous = grpcContext.attach(); + + ServerStreamTracer tracer; + try { + tracer = tracerFactory.newServerStreamTracer( + method.getFullMethodName(), new Metadata()); + tracer.filterContext(grpcContext); + tracer.serverCallStarted( + new 
CallInfo<>(method, Attributes.EMPTY, null)); + } finally { + grpcContext.detach(previous); + } + + try (io.opentelemetry.context.Scope scope = Context.root().makeCurrent()) { + tracer.streamClosed(Status.CANCELLED); + } + + ArgumentCaptor contextCaptor = ArgumentCaptor.forClass(Context.class); + verify(mockDurationHistogram).record( + anyDouble(), + any(), + contextCaptor.capture()); + + Baggage capturedBaggage = Baggage.fromContext(contextCaptor.getValue()); + assertNotNull("Captured context should have baggage", capturedBaggage); + assertEquals( + "baggage-val-1", capturedBaggage.getEntryValue("baggage-key-1")); + } + + @Test + public void serverMetrics_recordsBaggage_endToEnd() throws Exception { + DoubleHistogram mockDurationHistogram = mock(DoubleHistogram.class); + OpenTelemetryMetricsResource mockResource = OpenTelemetryMetricsResource.builder() + .serverCallDurationCounter(mockDurationHistogram) + .build(); + + OpenTelemetry openTelemetry = OpenTelemetrySdk + .builder() + .setPropagators(ContextPropagators.create( + W3CBaggagePropagator.getInstance())) + .build(); + + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(mockResource); + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule(openTelemetry); + + String serverName = InProcessServerBuilder.generateName(); + InProcessServerBuilder serverBuilder = InProcessServerBuilder + .forName(serverName).directExecutor(); + + serverBuilder.addStreamTracerFactory(tracingModule.getServerTracerFactory()); + serverBuilder.intercept(tracingModule.getServerSpanPropagationInterceptor()); + serverBuilder.addStreamTracerFactory(module.getServerTracerFactory()); + + serverBuilder.addService(ServerServiceDefinition.builder( + ServiceDescriptor.newBuilder("package1.service2") + .addMethod(method) + .build()) + .addMethod(method, new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + call.sendHeaders(new Metadata()); + 
call.sendMessage("response"); + call.close(Status.OK, new Metadata()); + return new ServerCall.Listener() { + }; + } + }).build()); + grpcCleanup.register(serverBuilder.build().start()); + + InProcessChannelBuilder channelBuilder = InProcessChannelBuilder + .forName(serverName).directExecutor(); + channelBuilder.intercept(tracingModule.getClientInterceptor()); + channelBuilder.intercept(module.getClientInterceptor(serverName)); + Channel channel = grpcCleanup.register(channelBuilder.intercept(new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + }).build()); + + Baggage baggage = Baggage.builder() + .put("baggage-key-1", "baggage-val-1") + .build(); + + Context otelContext = Context.root().with(baggage); + + try (Scope scope = otelContext.makeCurrent()) { + ClientCalls.blockingUnaryCall(channel, + method, CallOptions.DEFAULT, "request"); + } + + ArgumentCaptor contextCaptor = ArgumentCaptor.forClass(Context.class); + verify(mockDurationHistogram).record( + anyDouble(), + any(), + contextCaptor.capture()); + + Baggage capturedBaggage = Baggage.fromContext(contextCaptor.getValue()); + assertNotNull("Captured context should have baggage", capturedBaggage); + assertEquals( + "baggage-val-1", capturedBaggage.getEntryValue("baggage-key-1")); + } + + private static List sortByName(List metrics) { + metrics.sort((m1, m2) -> m1.getName().compareTo(m2.getName())); + return metrics; + } } diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java index 68cba17e802..e6759aadb1e 100644 --- a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java @@ -17,13 +17,17 @@ package io.grpc.opentelemetry; 
import static io.grpc.ClientStreamTracer.NAME_RESOLUTION_DELAYED; -import static io.grpc.opentelemetry.OpenTelemetryTracingModule.OTEL_TRACING_SCOPE_NAME; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.BAGGAGE_KEY; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -38,22 +42,38 @@ import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; import io.grpc.ClientStreamTracer; +import io.grpc.KnownLength; +import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.NoopServerCall; +import io.grpc.Server; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; import io.grpc.ServerServiceDefinition; import io.grpc.ServerStreamTracer; import io.grpc.Status; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.opentelemetry.OpenTelemetryTracingModule.CallAttemptsTracerFactory; +import io.grpc.opentelemetry.internal.OpenTelemetryConstants; +import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.GrpcServerRule; import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.baggage.Baggage; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.SpanBuilder; +import io.opentelemetry.api.trace.SpanContext; import io.opentelemetry.api.trace.SpanId; import io.opentelemetry.api.trace.StatusCode; +import 
io.opentelemetry.api.trace.TraceFlags; import io.opentelemetry.api.trace.TraceId; +import io.opentelemetry.api.trace.TraceState; import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.api.trace.TracerBuilder; +import io.opentelemetry.api.trace.TracerProvider; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; import io.opentelemetry.context.propagation.ContextPropagators; @@ -61,6 +81,7 @@ import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.EventData; import io.opentelemetry.sdk.trace.data.SpanData; +import java.io.IOException; import java.io.InputStream; import java.util.Arrays; import java.util.List; @@ -90,7 +111,7 @@ public class OpenTelemetryTracingModuleTest { private static final CallOptions CALL_OPTIONS = CallOptions.DEFAULT.withOption(CUSTOM_OPTION, "customvalue"); - private static class StringInputStream extends InputStream { + private static class StringInputStream extends InputStream implements KnownLength { final String string; StringInputStream(String string) { @@ -99,10 +120,15 @@ private static class StringInputStream extends InputStream { @Override public int read() { - // InProcessTransport doesn't actually read bytes from the InputStream. The InputStream is + // InProcessTransport doesn't actually read bytes from the InputStream. The InputStream is // passed to the InProcess server and consumed by MARSHALLER.parse(). throw new UnsupportedOperationException("Should not be called"); } + + @Override + public int available() throws IOException { + return string == null ? 
0 : string.length(); + } } private static final MethodDescriptor.Marshaller MARSHALLER = @@ -130,6 +156,8 @@ public String parse(InputStream stream) { public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); @Rule public final GrpcServerRule grpcServerRule = new GrpcServerRule().directExecutor(); + @Rule + public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); private Tracer tracerRule; @Mock private Tracer mockTracer; @@ -156,8 +184,15 @@ public String parse(InputStream stream) { @Before public void setUp() { - tracerRule = openTelemetryRule.getOpenTelemetry().getTracer(OTEL_TRACING_SCOPE_NAME); - when(mockOpenTelemetry.getTracer(OTEL_TRACING_SCOPE_NAME)).thenReturn(mockTracer); + tracerRule = openTelemetryRule.getOpenTelemetry().getTracer( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE); + TracerProvider mockTracerProvider = mock(TracerProvider.class); + when(mockOpenTelemetry.getTracerProvider()).thenReturn(mockTracerProvider); + TracerBuilder mockTracerBuilder = mock(TracerBuilder.class); + when(mockTracerProvider.tracerBuilder(OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .thenReturn(mockTracerBuilder); + when(mockTracerBuilder.setInstrumentationVersion(any())).thenReturn(mockTracerBuilder); + when(mockTracerBuilder.build()).thenReturn(mockTracer); when(mockOpenTelemetry.getPropagators()).thenReturn(ContextPropagators.create(mockPropagator)); when(mockSpanBuilder.startSpan()).thenReturn(mockAttemptSpan); when(mockSpanBuilder.setParent(any())).thenReturn(mockSpanBuilder); @@ -203,7 +238,7 @@ public void clientBasicTracingMocking() { List events = eventNameCaptor.getAllValues(); List attributes = attributesCaptor.getAllValues(); assertEquals( - "Outbound message sent" , + "Outbound message" , events.get(0)); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -213,7 +248,7 @@ public void clientBasicTracingMocking() { attributes.get(0)); assertEquals( - "Outbound message sent" , + "Outbound message" , 
events.get(1)); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -285,7 +320,7 @@ public void clientBasicTracingRule() { assertTrue(clientSpanEvents.get(0).getAttributes().isEmpty()); assertEquals( - "Inbound message received" , + "Inbound message" , clientSpanEvents.get(1).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -295,7 +330,7 @@ public void clientBasicTracingRule() { clientSpanEvents.get(1).getAttributes()); assertEquals( - "Inbound message received" , + "Inbound message" , clientSpanEvents.get(2).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -314,7 +349,7 @@ public void clientBasicTracingRule() { assertTrue(clientSpanEvents.get(0).getAttributes().isEmpty()); assertEquals( - "Outbound message sent" , + "Outbound message" , attemptSpanEvents.get(1).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -324,7 +359,7 @@ public void clientBasicTracingRule() { attemptSpanEvents.get(1).getAttributes()); assertEquals( - "Outbound message sent" , + "Outbound message" , attemptSpanEvents.get(2).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -451,7 +486,8 @@ public ClientCall interceptCall( @Test public void clientStreamNeverCreatedStillRecordTracing() { - OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule(mockOpenTelemetry); + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); CallAttemptsTracerFactory callTracer = tracingModule.newClientCallTracer(mockClientSpan, method); @@ -489,7 +525,7 @@ public void serverBasicTracingNoHeaders() { List events = spans.get(0).getEvents(); assertEquals(events.size(), 4); assertEquals( - "Outbound message sent" , + "Outbound message" , events.get(0).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -500,7 +536,7 @@ public void serverBasicTracingNoHeaders() { 
events.get(0).getAttributes()); assertEquals( - "Outbound message sent" , + "Outbound message" , events.get(1).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -520,7 +556,7 @@ public void serverBasicTracingNoHeaders() { events.get(2).getAttributes()); assertEquals( - "Inbound message received" , + "Inbound message" , events.get(3).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -570,6 +606,266 @@ public void grpcTraceBinPropagator() { Span.fromContext(contextArgumentCaptor.getValue()).getSpanContext()); } + @Test + public void testServerParentSpanPropagation() throws Exception { + final AtomicReference applicationSpan = new AtomicReference<>(); + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + ServerServiceDefinition serviceDefinition = + ServerServiceDefinition.builder("package1.service2").addMethod( + method, new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + applicationSpan.set(Span.fromContext(Context.current())); + call.sendHeaders(new Metadata()); + call.sendMessage("Hello"); + call.close( + Status.PERMISSION_DENIED.withDescription("No you don't"), new Metadata()); + return mockServerCallListener; + } + }).build(); + + Server server = InProcessServerBuilder.forName("test-server-span") + .addService( + ServerInterceptors.intercept(serviceDefinition, + tracingModule.getServerSpanPropagationInterceptor())) + .addStreamTracerFactory(tracingModule.getServerTracerFactory()) + .directExecutor().build().start(); + grpcCleanupRule.register(server); + + ManagedChannel channel = InProcessChannelBuilder.forName("test-server-span") + .directExecutor().build(); + grpcCleanupRule.register(channel); + + Span parentSpan = tracerRule.spanBuilder("test-parent-span").startSpan(); + try (Scope scope = Context.current().with(parentSpan).makeCurrent()) { + Channel interceptedChannel = + 
ClientInterceptors.intercept( + channel, tracingModule.getClientInterceptor()); + ClientCall call = interceptedChannel.newCall(method, CALL_OPTIONS); + Metadata headers = new Metadata(); + call.start(mockClientCallListener, headers); + + // End the call + call.halfClose(); + call.request(1); + parentSpan.end(); + } + + verify(mockClientCallListener).onClose(statusCaptor.capture(), any(Metadata.class)); + Status rpcStatus = statusCaptor.getValue(); + assertEquals(rpcStatus.getCode(), Status.Code.PERMISSION_DENIED); + assertEquals(rpcStatus.getDescription(), "No you don't"); + assertEquals(applicationSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + + List spans = openTelemetryRule.getSpans(); + assertEquals(spans.size(), 4); + SpanData clientSpan = spans.get(2); + SpanData attemptSpan = spans.get(1); + + assertEquals(clientSpan.getName(), "Sent.package1.service2.method3"); + assertTrue(clientSpan.hasEnded()); + assertEquals(clientSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(clientSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + + assertEquals(attemptSpan.getName(), "Attempt.package1.service2.method3"); + assertTrue(attemptSpan.hasEnded()); + assertEquals(attemptSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(attemptSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + + SpanData serverSpan = spans.get(0); + assertEquals(serverSpan.getName(), "Recv.package1.service2.method3"); + assertTrue(serverSpan.hasEnded()); + assertEquals(serverSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(serverSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + } + + @Test + public void serverSpanPropagationInterceptor() throws Exception { + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + Server server = 
InProcessServerBuilder.forName("test-span-propagation-interceptor") + .directExecutor().build().start(); + grpcCleanupRule.register(server); + final AtomicReference callbackSpan = new AtomicReference<>(); + ServerCall.Listener getContextListener = new ServerCall.Listener() { + @Override + public void onMessage(Integer message) { + callbackSpan.set(Span.fromContext(Context.current())); + } + + @Override + public void onHalfClose() { + callbackSpan.set(Span.fromContext(Context.current())); + } + + @Override + public void onCancel() { + callbackSpan.set(Span.fromContext(Context.current())); + } + + @Override + public void onComplete() { + callbackSpan.set(Span.fromContext(Context.current())); + } + }; + ServerInterceptor interceptor = tracingModule.getServerSpanPropagationInterceptor(); + @SuppressWarnings("unchecked") + ServerCallHandler handler = mock(ServerCallHandler.class); + when(handler.startCall(any(), any())).thenReturn(getContextListener); + ServerCall call = new NoopServerCall<>(); + Metadata metadata = new Metadata(); + ServerCall.Listener listener = interceptor.interceptCall(call, metadata, handler); + verify(handler).startCall(same(call), same(metadata)); + listener.onMessage(1); + assertEquals(callbackSpan.get(), Span.getInvalid()); + listener.onReady(); + assertEquals(callbackSpan.get(), Span.getInvalid()); + listener.onCancel(); + assertEquals(callbackSpan.get(), Span.getInvalid()); + listener.onHalfClose(); + assertEquals(callbackSpan.get(), Span.getInvalid()); + listener.onComplete(); + assertEquals(callbackSpan.get(), Span.getInvalid()); + + Span parentSpan = tracerRule.spanBuilder("parent-span").startSpan(); + io.grpc.Context context = io.grpc.Context.current().withValue( + tracingModule.otelSpan, parentSpan); + io.grpc.Context previous = context.attach(); + try { + listener = interceptor.interceptCall(call, metadata, handler); + verify(handler, times(2)).startCall(same(call), same(metadata)); + listener.onMessage(1); + 
assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + listener.onReady(); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + listener.onCancel(); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + listener.onHalfClose(); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + listener.onComplete(); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + } finally { + context.detach(previous); + } + } + + /** + * Tests that baggage from the initial context is propagated + * to the context active during the next handler's execution. + */ + @Test + public void testBaggageIsPropagatedToHandlerContext() { + // 1. ARRANGE + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + ServerInterceptor interceptor = tracingModule.getServerSpanPropagationInterceptor(); + + // Create mocks for the gRPC call chain + @SuppressWarnings("unchecked") + ServerCallHandler mockHandler = mock(ServerCallHandler.class); + @SuppressWarnings("unchecked") + ServerCall.Listener mockListener = mock(ServerCall.Listener.class); + ServerCall mockCall = new NoopServerCall<>(); + Metadata mockHeaders = new Metadata(); + + // Create a non-null Span (required to pass the first 'if' check) + Span testSpan = Span.wrap( + SpanContext.create("time-period", "star-wars", + TraceFlags.getSampled(), TraceState.getDefault())); + + // Create the test Baggage + Baggage testBaggage = Baggage.builder().put("best-bot", "R2D2").build(); + + // Create the initial gRPC context that the interceptor will read from + io.grpc.Context initialGrpcContext = io.grpc.Context.current() + .withValue(tracingModule.otelSpan, testSpan) + .withValue(BAGGAGE_KEY, testBaggage); + + // This 
AtomicReference will capture the Baggage from *within* the handler + final AtomicReference capturedBaggage = new AtomicReference<>(); + + // Stub the handler to capture the *current* context when it's called + doAnswer(invocation -> { + // Baggage.current() gets baggage from io.opentelemetry.context.Context.current() + capturedBaggage.set(Baggage.current()); + return mockListener; + }).when(mockHandler).startCall(any(), any()); + + // 2. ACT + // Run the interceptCall method within the prepared context + io.grpc.Context previous = initialGrpcContext.attach(); + try { + interceptor.interceptCall(mockCall, mockHeaders, mockHandler); + } finally { + initialGrpcContext.detach(previous); + } + + // 3. ASSERT + // Verify the next handler was called + verify(mockHandler).startCall(same(mockCall), same(mockHeaders)); + + // Check the baggage that was captured + assertNotNull("Baggage should not be null in handler context", capturedBaggage.get()); + assertEquals("Baggage was not correctly propagated to the handler's context", + "R2D2", capturedBaggage.get().getEntryValue("best-bot")); + } + + /** + * Tests that the interceptor proceeds correctly if baggage is null or empty. + */ + @Test + public void testNullBaggageIsHandledGracefully() { + // 1. 
ARRANGE + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + ServerInterceptor interceptor = tracingModule.getServerSpanPropagationInterceptor(); + + @SuppressWarnings("unchecked") + ServerCallHandler mockHandler = mock(ServerCallHandler.class); + @SuppressWarnings("unchecked") + ServerCall.Listener mockListener = mock(ServerCall.Listener.class); + ServerCall mockCall = new NoopServerCall<>(); + Metadata mockHeaders = new Metadata(); + + Span testSpan = Span.getInvalid(); // A non-null span + + // No baggage is set in the context + io.grpc.Context initialGrpcContext = io.grpc.Context.current() + .withValue(tracingModule.otelSpan, testSpan); + + final AtomicReference capturedBaggage = new AtomicReference<>(); + + // Stub the handler to capture the *current* context when it's called + doAnswer(invocation -> { + // Baggage.current() gets baggage from io.opentelemetry.context.Context.current() + capturedBaggage.set(Baggage.current()); + return mockListener; + }).when(mockHandler).startCall(any(), any()); + + // 2. ACT + io.grpc.Context previous = initialGrpcContext.attach(); + try { + interceptor.interceptCall(mockCall, mockHeaders, mockHandler); + } finally { + initialGrpcContext.detach(previous); + } + + // 3. 
ASSERT + verify(mockHandler).startCall(same(mockCall), same(mockHeaders)); + + // Baggage should be null in the downstream context + assertEquals("Baggage should be empty when not provided", + Baggage.empty(), capturedBaggage.get()); + } + @Test public void generateTraceSpanName() { assertEquals( diff --git a/protobuf-lite/BUILD.bazel b/protobuf-lite/BUILD.bazel index dad794e8b58..97a5e492d80 100644 --- a/protobuf-lite/BUILD.bazel +++ b/protobuf-lite/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( diff --git a/protobuf-lite/build.gradle b/protobuf-lite/build.gradle index 11a49d4816d..c1e5b51ae35 100644 --- a/protobuf-lite/build.gradle +++ b/protobuf-lite/build.gradle @@ -17,8 +17,16 @@ dependencies { testImplementation project(':grpc-core') - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("jar").configure { @@ -31,7 +39,7 @@ tasks.named("compileTestJava").configure { options.compilerArgs += [ "-Xlint:-cast" ] - options.errorprone.excludedPaths = ".*/build/generated/source/proto/.*" + options.errorprone.excludedPaths = ".*/build/generated/sources/proto/.*" } protobuf { diff --git a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java index 7e33fc67622..ef4b16bd476 100644 --- a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java +++ b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java @@ -89,12 +89,11 @@ public static Marshaller marshaller(T defaultInstance /** * Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a - * custom limit for the recursion depth. 
Any negative number will leave the limit to its default + * custom limit for the recursion depth. Any negative number will leave the limit as its default * value as defined by the protobuf library. * * @since 1.56.0 */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10108") public static Marshaller marshallerWithRecursionLimit( T defaultInstance, int recursionLimit) { return new MessageMarshaller<>(defaultInstance, recursionLimit); diff --git a/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java b/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java index 5c25cb3b309..204264b016d 100644 --- a/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java +++ b/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java @@ -16,6 +16,7 @@ package io.grpc.protobuf.lite; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -43,9 +44,7 @@ import java.io.IOException; import java.io.InputStream; import java.util.Arrays; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -53,9 +52,6 @@ @RunWith(JUnit4.class) public class ProtoLiteUtilsTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); - private final Marshaller marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance()); private Type proto = Type.newBuilder().setName("name").build(); @@ -214,10 +210,9 @@ public void metadataMarshaller_invalid() { @Test public void extensionRegistry_notNull() { - thrown.expect(NullPointerException.class); - thrown.expectMessage("newRegistry"); - - ProtoLiteUtils.setExtensionRegistry(null); + NullPointerException e = 
assertThrows(NullPointerException.class, + () -> ProtoLiteUtils.setExtensionRegistry(null)); + assertThat(e).hasMessageThat().isEqualTo("newRegistry"); } @Test diff --git a/protobuf/BUILD.bazel b/protobuf/BUILD.bazel index 724c78ca6ee..a31f8b6f6f5 100644 --- a/protobuf/BUILD.bazel +++ b/protobuf/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( @@ -12,6 +13,7 @@ java_library( "@com_google_protobuf//:protobuf_java", artifact("com.google.api.grpc:proto-google-common-protos"), artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), ], ) diff --git a/protobuf/build.gradle b/protobuf/build.gradle index c88ae836e0f..c477e41dceb 100644 --- a/protobuf/build.gradle +++ b/protobuf/build.gradle @@ -31,8 +31,16 @@ dependencies { exclude group: 'com.google.protobuf', module: 'protobuf-javalite' } - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("javadoc").configure { diff --git a/protobuf/src/main/java/io/grpc/protobuf/ProtoMethodDescriptorSupplier.java b/protobuf/src/main/java/io/grpc/protobuf/ProtoMethodDescriptorSupplier.java index e5b2f38e3c0..e7cd3ed336f 100644 --- a/protobuf/src/main/java/io/grpc/protobuf/ProtoMethodDescriptorSupplier.java +++ b/protobuf/src/main/java/io/grpc/protobuf/ProtoMethodDescriptorSupplier.java @@ -16,8 +16,8 @@ package io.grpc.protobuf; +import com.google.errorprone.annotations.CheckReturnValue; import com.google.protobuf.Descriptors.MethodDescriptor; -import javax.annotation.CheckReturnValue; /** * Provides access to the underlying proto service method descriptor. 
diff --git a/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java b/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java index 933d598996c..d403789eb5f 100644 --- a/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java +++ b/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java @@ -18,7 +18,6 @@ import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; -import io.grpc.ExperimentalApi; import io.grpc.Metadata; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.protobuf.lite.ProtoLiteUtils; @@ -58,12 +57,11 @@ public static Marshaller marshaller(final T defaultInstan /** * Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a - * custom limit for the recursion depth. Any negative number will leave the limit to its default + * custom limit for the recursion depth. Any negative number will leave the limit as its default * value as defined by the protobuf library. * * @since 1.56.0 */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10108") public static Marshaller marshallerWithRecursionLimit(T defaultInstance, int recursionLimit) { return ProtoLiteUtils.marshallerWithRecursionLimit(defaultInstance, recursionLimit); diff --git a/repositories.bzl b/repositories.bzl index 455e9dcf3ca..0a09cece070 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -12,40 +12,42 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # GRPC_DEPS_START IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "com.google.android:annotations:4.1.1.4", - "com.google.api.grpc:proto-google-common-protos:2.29.0", - "com.google.auth:google-auth-library-credentials:1.23.0", - "com.google.auth:google-auth-library-oauth2-http:1.23.0", + "com.google.api.grpc:proto-google-common-protos:2.64.1", + "com.google.auth:google-auth-library-credentials:1.42.1", + "com.google.auth:google-auth-library-oauth2-http:1.42.1", "com.google.auto.value:auto-value-annotations:1.11.0", "com.google.auto.value:auto-value:1.11.0", 
"com.google.code.findbugs:jsr305:3.0.2", - "com.google.code.gson:gson:2.11.0", - "com.google.errorprone:error_prone_annotations:2.28.0", + "com.google.code.gson:gson:2.13.2", + "com.google.errorprone:error_prone_annotations:2.48.0", "com.google.guava:failureaccess:1.0.1", - "com.google.guava:guava:33.2.1-android", - "com.google.re2j:re2j:1.7", - "com.google.truth:truth:1.4.2", + "com.google.guava:guava:33.5.0-android", + "com.google.re2j:re2j:1.8", + "com.google.s2a.proto.v2:s2a-proto:0.1.3", + "com.google.truth:truth:1.4.5", "com.squareup.okhttp:okhttp:2.7.5", "com.squareup.okio:okio:2.10.0", # 3.0+ needs swapping to -jvm; need work to avoid flag-day - "io.netty:netty-buffer:4.1.110.Final", - "io.netty:netty-codec-http2:4.1.110.Final", - "io.netty:netty-codec-http:4.1.110.Final", - "io.netty:netty-codec-socks:4.1.110.Final", - "io.netty:netty-codec:4.1.110.Final", - "io.netty:netty-common:4.1.110.Final", - "io.netty:netty-handler-proxy:4.1.110.Final", - "io.netty:netty-handler:4.1.110.Final", - "io.netty:netty-resolver:4.1.110.Final", - "io.netty:netty-tcnative-boringssl-static:2.0.65.Final", - "io.netty:netty-tcnative-classes:2.0.65.Final", - "io.netty:netty-transport-native-epoll:jar:linux-x86_64:4.1.110.Final", - "io.netty:netty-transport-native-unix-common:4.1.110.Final", - "io.netty:netty-transport:4.1.110.Final", + "io.netty:netty-buffer:4.1.132.Final", + "io.netty:netty-codec-http2:4.1.132.Final", + "io.netty:netty-codec-http:4.1.132.Final", + "io.netty:netty-codec-socks:4.1.132.Final", + "io.netty:netty-codec:4.1.132.Final", + "io.netty:netty-common:4.1.132.Final", + "io.netty:netty-handler-proxy:4.1.132.Final", + "io.netty:netty-handler:4.1.132.Final", + "io.netty:netty-resolver:4.1.132.Final", + "io.netty:netty-tcnative-boringssl-static:2.0.75.Final", + "io.netty:netty-tcnative-classes:2.0.75.Final", + "io.netty:netty-transport-native-epoll:jar:linux-x86_64:4.1.132.Final", + "io.netty:netty-transport-native-unix-common:4.1.132.Final", + 
"io.netty:netty-transport:4.1.132.Final", "io.opencensus:opencensus-api:0.31.0", "io.opencensus:opencensus-contrib-grpc-metrics:0.31.0", "io.perfmark:perfmark-api:0.27.0", "junit:junit:4.13.2", - "org.apache.tomcat:annotations-api:6.0.53", - "org.codehaus.mojo:animal-sniffer-annotations:1.24", + "org.mockito:mockito-core:4.4.0", + "org.checkerframework:checker-qual:3.49.5", + "org.codehaus.mojo:animal-sniffer-annotations:1.27", ] # GRPC_DEPS_END @@ -80,43 +82,17 @@ IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS = { "io.grpc:grpc-rls": "@io_grpc_grpc_java//rls", "io.grpc:grpc-services": "@io_grpc_grpc_java//services:services_maven", "io.grpc:grpc-stub": "@io_grpc_grpc_java//stub", + "io.grpc:grpc-s2a": "@io_grpc_grpc_java//s2a", "io.grpc:grpc-testing": "@io_grpc_grpc_java//testing", "io.grpc:grpc-xds": "@io_grpc_grpc_java//xds:xds_maven", "io.grpc:grpc-util": "@io_grpc_grpc_java//util", } -def grpc_java_repositories(bzlmod = False): +def grpc_java_repositories(): """Imports dependencies for grpc-java.""" - if not bzlmod and not native.existing_rule("dev_cel"): - http_archive( - name = "dev_cel", - strip_prefix = "cel-spec-0.15.0", - sha256 = "3ee09eb69dbe77722e9dee23dc48dc2cd9f765869fcf5ffb1226587c81791a0b", - urls = [ - "https://github.com/google/cel-spec/archive/refs/tags/v0.15.0.tar.gz", - ], - ) - if not native.existing_rule("com_github_cncf_xds"): - http_archive( - name = "com_github_cncf_xds", - strip_prefix = "xds-024c85f92f20cab567a83acc50934c7f9711d124", - sha256 = "5f403aa681711500ca8e62387be3e37d971977db6e88616fc21862a406430649", - urls = [ - "https://github.com/cncf/xds/archive/024c85f92f20cab567a83acc50934c7f9711d124.tar.gz", - ], - ) - if not bzlmod and not native.existing_rule("com_github_grpc_grpc"): - http_archive( - name = "com_github_grpc_grpc", - strip_prefix = "grpc-1.46.0", - sha256 = "67423a4cd706ce16a88d1549297023f0f9f0d695a96dd684adc21e67b021f9bc", - urls = [ - "https://github.com/grpc/grpc/archive/v1.46.0.tar.gz", - ], - ) - if not bzlmod and not 
native.existing_rule("com_google_protobuf"): + if not native.existing_rule("com_google_protobuf"): com_google_protobuf() - if not bzlmod and not native.existing_rule("com_google_googleapis"): + if not native.existing_rule("com_google_googleapis"): http_archive( name = "com_google_googleapis", sha256 = "49930468563dd48283e8301e8d4e71436bf6d27ac27c235224cc1a098710835d", @@ -125,25 +101,14 @@ def grpc_java_repositories(bzlmod = False): "https://github.com/googleapis/googleapis/archive/ca1372c6d7bcb199638ebfdb40d2b2660bab7b88.tar.gz", ], ) - if not bzlmod and not native.existing_rule("io_bazel_rules_go"): - http_archive( - name = "io_bazel_rules_go", - sha256 = "ab21448cef298740765f33a7f5acee0607203e4ea321219f2a4c85a6e0fb0a27", - urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.32.0/rules_go-v0.32.0.zip", - "https://github.com/bazelbuild/rules_go/releases/download/v0.32.0/rules_go-v0.32.0.zip", - ], - ) if not native.existing_rule("io_grpc_grpc_proto"): io_grpc_grpc_proto() - if not native.existing_rule("envoy_api"): + if not native.existing_rule("bazel_jar_jar"): http_archive( - name = "envoy_api", - sha256 = "cb7cd388eaa297320d392c872ceb82571dee71f4b6f1c4546b0c0a399636f523", - strip_prefix = "data-plane-api-874e3aa8c3aa5086b6bffa2166e0e0077bb32f71", - urls = [ - "https://github.com/envoyproxy/data-plane-api/archive/874e3aa8c3aa5086b6bffa2166e0e0077bb32f71.tar.gz", - ], + name = "bazel_jar_jar", + sha256 = "3117f913c732142a795551f530d02c9157b9ea895e6b2de0fbb5af54f03040a5", + strip_prefix = "bazel_jar_jar-0.1.6", + url = "https://github.com/bazeltools/bazel_jar_jar/releases/download/v0.1.6/bazel_jar_jar-v0.1.6.tar.gz", ) def com_google_protobuf(): @@ -152,9 +117,9 @@ def com_google_protobuf(): # This statement defines the @com_google_protobuf repo. 
http_archive( name = "com_google_protobuf", - sha256 = "9bd87b8280ef720d3240514f884e56a712f2218f0d693b48050c836028940a42", - strip_prefix = "protobuf-25.1", - urls = ["https://github.com/protocolbuffers/protobuf/releases/download/v25.1/protobuf-25.1.tar.gz"], + sha256 = "bc670a4e34992c175137ddda24e76562bb928f849d712a0e3c2fb2e19249bea1", + strip_prefix = "protobuf-33.4", + urls = ["https://github.com/protocolbuffers/protobuf/releases/download/v33.4/protobuf-33.4.tar.gz"], ) def io_grpc_grpc_proto(): @@ -164,8 +129,3 @@ def io_grpc_grpc_proto(): strip_prefix = "grpc-proto-4f245d272a28a680606c0739753506880cf33b5f", urls = ["https://github.com/grpc/grpc-proto/archive/4f245d272a28a680606c0739753506880cf33b5f.zip"], ) - -def _grpc_java_repositories_extension(_): - grpc_java_repositories(bzlmod = True) - -grpc_java_repositories_extension = module_extension(implementation = _grpc_java_repositories_extension) diff --git a/rls/BUILD.bazel b/rls/BUILD.bazel index 10a5e22524a..70c17a9c8b6 100644 --- a/rls/BUILD.bazel +++ b/rls/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") load("//:java_grpc_library.bzl", "java_grpc_library") @@ -19,6 +20,7 @@ java_library( "@io_grpc_grpc_proto//:rls_java_proto", artifact("com.google.auto.value:auto-value-annotations"), artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), ], ) diff --git a/rls/build.gradle b/rls/build.gradle index 0629dce64c1..10b1d5fc371 100644 --- a/rls/build.gradle +++ b/rls/build.gradle @@ -22,7 +22,6 @@ dependencies { libraries.auto.value.annotations, libraries.guava annotationProcessor libraries.auto.value - compileOnly libraries.javax.annotation testImplementation libraries.truth, project(':grpc-grpclb'), project(':grpc-inprocess'), @@ -30,7 +29,11 @@ dependencies { project(':grpc-testing-proto'), testFixtures(project(':grpc-api')), 
testFixtures(project(':grpc-core')) - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } tasks.named("compileJava").configure { @@ -46,7 +49,7 @@ tasks.named("compileJava").configure { tasks.named("javadoc").configure { // Do not publish javadoc since currently there is no public API. - failOnError false // no public or protected classes found to document + failOnError = false // no public or protected classes found to document exclude 'io/grpc/lookup/v1/**' exclude 'io/grpc/rls/*Provider.java' exclude 'io/grpc/rls/internal/**' diff --git a/rls/src/generated/main/grpc/io/grpc/lookup/v1/RouteLookupServiceGrpc.java b/rls/src/generated/main/grpc/io/grpc/lookup/v1/RouteLookupServiceGrpc.java index d7334b942ff..be060e576a4 100644 --- a/rls/src/generated/main/grpc/io/grpc/lookup/v1/RouteLookupServiceGrpc.java +++ b/rls/src/generated/main/grpc/io/grpc/lookup/v1/RouteLookupServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/lookup/v1/rls.proto") @io.grpc.stub.annotations.GrpcGenerated public final class RouteLookupServiceGrpc { @@ -60,6 +57,21 @@ public RouteLookupServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptio return RouteLookupServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static RouteLookupServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public RouteLookupServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RouteLookupServiceBlockingV2Stub(channel, callOptions); + } + }; + return RouteLookupServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and 
streaming output calls on the service */ @@ -147,6 +159,33 @@ public void routeLookup(io.grpc.lookup.v1.RouteLookupRequest request, /** * A stub to allow clients to do synchronous rpc calls to service RouteLookupService. */ + public static final class RouteLookupServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private RouteLookupServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected RouteLookupServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RouteLookupServiceBlockingV2Stub(channel, callOptions); + } + + /** + *

+     * Lookup returns a target for a single key.
+     * 
+ */ + public io.grpc.lookup.v1.RouteLookupResponse routeLookup(io.grpc.lookup.v1.RouteLookupRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getRouteLookupMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service RouteLookupService. + */ public static final class RouteLookupServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private RouteLookupServiceBlockingStub( diff --git a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java index d0661ba3be8..a2846fd04c8 100644 --- a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java +++ b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java @@ -28,9 +28,12 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.CheckReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityState; +import io.grpc.Grpc; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -51,7 +54,6 @@ import io.grpc.lookup.v1.RouteLookupServiceGrpc; import io.grpc.lookup.v1.RouteLookupServiceGrpc.RouteLookupServiceStub; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; -import io.grpc.rls.LbPolicyConfiguration.ChildLbStatusListener; import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper; import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory; import io.grpc.rls.LruCache.EvictionListener; @@ -59,6 +61,7 @@ import io.grpc.rls.RlsProtoConverters.RouteLookupResponseConverter; import io.grpc.rls.RlsProtoData.RouteLookupConfig; import 
io.grpc.rls.RlsProtoData.RouteLookupRequest; +import io.grpc.rls.RlsProtoData.RouteLookupRequestKey; import io.grpc.rls.RlsProtoData.RouteLookupResponse; import io.grpc.stub.StreamObserver; import io.grpc.util.ForwardingLoadBalancerHelper; @@ -73,9 +76,7 @@ import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -112,7 +113,7 @@ final class CachingRlsLbClient { private final Future periodicCleaner; // any RPC on the fly will cached in this map @GuardedBy("lock") - private final Map pendingCallCache = new HashMap<>(); + private final Map pendingCallCache = new HashMap<>(); private final ScheduledExecutorService scheduledExecutorService; private final Ticker ticker; @@ -132,6 +133,7 @@ final class CachingRlsLbClient { @GuardedBy("lock") private final RefCountedChildPolicyWrapperFactory refCountedChildPolicyWrapperFactory; private final ChannelLogger logger; + private final ChildPolicyWrapper fallbackChildPolicyWrapper; static { MetricInstrumentRegistry metricInstrumentRegistry @@ -140,20 +142,22 @@ final class CachingRlsLbClient { "grpc.lb.rls.default_target_picks", "EXPERIMENTAL. Number of LB picks sent to the default target", "{pick}", Arrays.asList("grpc.target", "grpc.lb.rls.server_target", - "grpc.lb.rls.data_plane_target", "grpc.lb.pick_result"), Collections.emptyList(), + "grpc.lb.rls.data_plane_target", "grpc.lb.pick_result"), + Arrays.asList("grpc.client.call.custom"), false); TARGET_PICKS_COUNTER = metricInstrumentRegistry.registerLongCounter("grpc.lb.rls.target_picks", "EXPERIMENTAL. Number of LB picks sent to each RLS target. 
Note that if the default " + "target is also returned by the RLS server, RPCs sent to that target from the cache " + "will be counted in this metric, not in grpc.rls.default_target_picks.", "{pick}", Arrays.asList("grpc.target", "grpc.lb.rls.server_target", "grpc.lb.rls.data_plane_target", - "grpc.lb.pick_result"), Collections.emptyList(), + "grpc.lb.pick_result"), + Arrays.asList("grpc.client.call.custom"), false); FAILED_PICKS_COUNTER = metricInstrumentRegistry.registerLongCounter("grpc.lb.rls.failed_picks", "EXPERIMENTAL. Number of LB picks failed due to either a failed RLS request or the " + "RLS channel being throttled", "{pick}", Arrays.asList("grpc.target", "grpc.lb.rls.server_target"), - Collections.emptyList(), false); + Arrays.asList("grpc.client.call.custom"), false); CACHE_ENTRIES_GAUGE = metricInstrumentRegistry.registerLongGauge("grpc.lb.rls.cache_entries", "EXPERIMENTAL. Number of entries in the RLS cache", "{entry}", Arrays.asList("grpc.target", "grpc.lb.rls.server_target", "grpc.lb.rls.instance_uuid"), @@ -215,6 +219,35 @@ private CachingRlsLbClient(Builder builder) { rlsChannelBuilder.disableServiceConfigLookUp(); } rlsChannel = rlsChannelBuilder.build(); + Runnable rlsServerConnectivityStateChangeHandler = new Runnable() { + private boolean wasInTransientFailure; + @Override + public void run() { + ConnectivityState currentState = rlsChannel.getState(false); + if (currentState == ConnectivityState.TRANSIENT_FAILURE) { + wasInTransientFailure = true; + } else if (wasInTransientFailure && currentState == ConnectivityState.READY) { + wasInTransientFailure = false; + synchronized (lock) { + boolean anyBackoffsCanceled = false; + for (CacheEntry value : linkedHashLruCache.values()) { + if (value instanceof BackoffCacheEntry) { + if (((BackoffCacheEntry) value).scheduledFuture.cancel(false)) { + anyBackoffsCanceled = true; + } + } + } + if (anyBackoffsCanceled) { + // Cache updated. 
updateBalancingState() to reattempt picks + helper.triggerPendingRpcProcessing(); + } + } + } + rlsChannel.notifyWhenStateChanged(currentState, this); + } + }; + rlsChannel.notifyWhenStateChanged( + ConnectivityState.IDLE, rlsServerConnectivityStateChangeHandler); rlsStub = RouteLookupServiceGrpc.newStub(rlsChannel); childLbResolvedAddressFactory = checkNotNull(builder.resolvedAddressFactory, "resolvedAddressFactory"); @@ -224,8 +257,14 @@ private CachingRlsLbClient(Builder builder) { refCountedChildPolicyWrapperFactory = new RefCountedChildPolicyWrapperFactory( lbPolicyConfig.getLoadBalancingPolicy(), childLbResolvedAddressFactory, - childLbHelperProvider, - new BackoffRefreshListener()); + childLbHelperProvider); + // TODO(creamsoup) wait until lb is ready + String defaultTarget = lbPolicyConfig.getRouteLookupConfig().defaultTarget(); + if (defaultTarget != null && !defaultTarget.isEmpty()) { + fallbackChildPolicyWrapper = refCountedChildPolicyWrapperFactory.createOrGet(defaultTarget); + } else { + fallbackChildPolicyWrapper = null; + } gaugeRegistration = helper.getMetricRecorder() .registerBatchCallback(new BatchCallback() { @@ -255,6 +294,13 @@ void init() { } } + Status acceptResolvedAddressFactory(ResolvedAddressFactory childLbResolvedAddressFactory) { + synchronized (lock) { + return refCountedChildPolicyWrapperFactory.acceptResolvedAddressFactory( + childLbResolvedAddressFactory); + } + } + /** * Convert the status to UNAVAILABLE and enhance the error message. * @param status status as provided by server @@ -277,43 +323,48 @@ private void periodicClean() { /** Populates async cache entry for new request. 
*/ @GuardedBy("lock") private CachedRouteLookupResponse asyncRlsCall( - RouteLookupRequest request, @Nullable BackoffPolicy backoffPolicy) { - logger.log(ChannelLogLevel.DEBUG, "Making an async call to RLS"); + RouteLookupRequestKey routeLookupRequestKey, @Nullable BackoffPolicy backoffPolicy, + RouteLookupRequest.Reason routeLookupReason) { if (throttler.shouldThrottle()) { - logger.log(ChannelLogLevel.DEBUG, "Request is throttled"); + logger.log(ChannelLogLevel.DEBUG, "[RLS Entry {0}] Throttled RouteLookup", + routeLookupRequestKey); // Cache updated, but no need to call updateBalancingState because no RPCs were queued waiting // on this result return CachedRouteLookupResponse.backoffEntry(createBackOffEntry( - request, Status.RESOURCE_EXHAUSTED.withDescription("RLS throttled"), backoffPolicy)); + routeLookupRequestKey, Status.RESOURCE_EXHAUSTED.withDescription("RLS throttled"), + backoffPolicy)); } final SettableFuture response = SettableFuture.create(); - io.grpc.lookup.v1.RouteLookupRequest routeLookupRequest = REQUEST_CONVERTER.convert(request); - logger.log(ChannelLogLevel.DEBUG, "Sending RouteLookupRequest: {0}", routeLookupRequest); + io.grpc.lookup.v1.RouteLookupRequest routeLookupRequest = REQUEST_CONVERTER.convert( + RouteLookupRequest.create(routeLookupRequestKey.keyMap(), routeLookupReason)); + logger.log(ChannelLogLevel.DEBUG, + "[RLS Entry {0}] Starting RouteLookup: {1}", routeLookupRequestKey, routeLookupRequest); rlsStub.withDeadlineAfter(callTimeoutNanos, TimeUnit.NANOSECONDS) .routeLookup( routeLookupRequest, new StreamObserver() { @Override public void onNext(io.grpc.lookup.v1.RouteLookupResponse value) { - logger.log(ChannelLogLevel.DEBUG, "Received RouteLookupResponse: {0}", value); + logger.log(ChannelLogLevel.DEBUG, + "[RLS Entry {0}] RouteLookup succeeded: {1}", routeLookupRequestKey, value); response.set(RESPONSE_CONVERTER.reverse().convert(value)); } @Override public void onError(Throwable t) { - logger.log(ChannelLogLevel.DEBUG, "Error 
looking up route:", t); + logger.log(ChannelLogLevel.DEBUG, + "[RLS Entry {0}] RouteLookup failed: {1}", routeLookupRequestKey, t); response.setException(t); throttler.registerBackendResponse(true); } @Override public void onCompleted() { - logger.log(ChannelLogLevel.DEBUG, "routeLookup call completed"); throttler.registerBackendResponse(false); } }); return CachedRouteLookupResponse.pendingResponse( - createPendingEntry(request, response, backoffPolicy)); + createPendingEntry(routeLookupRequestKey, response, backoffPolicy)); } /** @@ -322,32 +373,30 @@ public void onCompleted() { * changed after the return. */ @CheckReturnValue - final CachedRouteLookupResponse get(final RouteLookupRequest request) { - logger.log(ChannelLogLevel.DEBUG, "Acquiring lock to get cached entry"); + final CachedRouteLookupResponse get(final RouteLookupRequestKey routeLookupRequestKey) { synchronized (lock) { - logger.log(ChannelLogLevel.DEBUG, "Acquired lock to get cached entry"); final CacheEntry cacheEntry; - cacheEntry = linkedHashLruCache.read(request); - if (cacheEntry == null) { - logger.log(ChannelLogLevel.DEBUG, "No cache entry found, making a new RLS request"); - PendingCacheEntry pendingEntry = pendingCallCache.get(request); + cacheEntry = linkedHashLruCache.read(routeLookupRequestKey); + if (cacheEntry == null + || (cacheEntry instanceof BackoffCacheEntry + && !((BackoffCacheEntry) cacheEntry).isInBackoffPeriod())) { + PendingCacheEntry pendingEntry = pendingCallCache.get(routeLookupRequestKey); if (pendingEntry != null) { return CachedRouteLookupResponse.pendingResponse(pendingEntry); } - return asyncRlsCall(request, /* backoffPolicy= */ null); + return asyncRlsCall(routeLookupRequestKey, cacheEntry instanceof BackoffCacheEntry + ? 
((BackoffCacheEntry) cacheEntry).backoffPolicy : null, + RouteLookupRequest.Reason.REASON_MISS); } if (cacheEntry instanceof DataCacheEntry) { // cache hit, initiate async-refresh if entry is staled - logger.log(ChannelLogLevel.DEBUG, "Cache hit for the request"); DataCacheEntry dataEntry = ((DataCacheEntry) cacheEntry); if (dataEntry.isStaled(ticker.read())) { - logger.log(ChannelLogLevel.DEBUG, "Cache entry is stale"); dataEntry.maybeRefresh(); } return CachedRouteLookupResponse.dataEntry((DataCacheEntry) cacheEntry); } - logger.log(ChannelLogLevel.DEBUG, "Cache hit for a backup entry"); return CachedRouteLookupResponse.backoffEntry((BackoffCacheEntry) cacheEntry); } } @@ -373,13 +422,14 @@ void requestConnection() { @GuardedBy("lock") private PendingCacheEntry createPendingEntry( - RouteLookupRequest request, + RouteLookupRequestKey routeLookupRequestKey, ListenableFuture pendingCall, @Nullable BackoffPolicy backoffPolicy) { - PendingCacheEntry entry = new PendingCacheEntry(request, pendingCall, backoffPolicy); + PendingCacheEntry entry = new PendingCacheEntry(routeLookupRequestKey, pendingCall, + backoffPolicy); // Add the entry to the map before adding the Listener, because the listener removes the // entry from the map - pendingCallCache.put(request, entry); + pendingCallCache.put(routeLookupRequestKey, entry); // Beware that the listener can run immediately on the current thread pendingCall.addListener(() -> pendingRpcComplete(entry), MoreExecutors.directExecutor()); return entry; @@ -387,17 +437,18 @@ private PendingCacheEntry createPendingEntry( private void pendingRpcComplete(PendingCacheEntry entry) { synchronized (lock) { - boolean clientClosed = pendingCallCache.remove(entry.request) == null; + boolean clientClosed = pendingCallCache.remove(entry.routeLookupRequestKey) == null; if (clientClosed) { return; } try { - createDataEntry(entry.request, Futures.getDone(entry.pendingCall)); + createDataEntry(entry.routeLookupRequestKey, 
Futures.getDone(entry.pendingCall)); // Cache updated. DataCacheEntry constructor indirectly calls updateBalancingState() to // reattempt picks when the child LB is done connecting } catch (Exception e) { - createBackOffEntry(entry.request, Status.fromThrowable(e), entry.backoffPolicy); + createBackOffEntry(entry.routeLookupRequestKey, Status.fromThrowable(e), + entry.backoffPolicy); // Cache updated. updateBalancingState() to reattempt picks helper.triggerPendingRpcProcessing(); } @@ -406,33 +457,35 @@ private void pendingRpcComplete(PendingCacheEntry entry) { @GuardedBy("lock") private DataCacheEntry createDataEntry( - RouteLookupRequest request, RouteLookupResponse routeLookupResponse) { + RouteLookupRequestKey routeLookupRequestKey, RouteLookupResponse routeLookupResponse) { logger.log( ChannelLogLevel.DEBUG, - "Transition to data cache: routeLookupResponse={0}", - routeLookupResponse); - DataCacheEntry entry = new DataCacheEntry(request, routeLookupResponse); + "[RLS Entry {0}] Transition to data cache: routeLookupResponse={1}", + routeLookupRequestKey, routeLookupResponse); + DataCacheEntry entry = new DataCacheEntry(routeLookupRequestKey, routeLookupResponse); // Constructor for DataCacheEntry causes updateBalancingState, but the picks can't happen until // this cache update because the lock is held - linkedHashLruCache.cacheAndClean(request, entry); + linkedHashLruCache.cacheAndClean(routeLookupRequestKey, entry); return entry; } @GuardedBy("lock") - private BackoffCacheEntry createBackOffEntry( - RouteLookupRequest request, Status status, @Nullable BackoffPolicy backoffPolicy) { - logger.log(ChannelLogLevel.DEBUG, "Transition to back off: status={0}", status); + private BackoffCacheEntry createBackOffEntry(RouteLookupRequestKey routeLookupRequestKey, + Status status, @Nullable BackoffPolicy backoffPolicy) { if (backoffPolicy == null) { backoffPolicy = backoffProvider.get(); } long delayNanos = backoffPolicy.nextBackoffNanos(); - BackoffCacheEntry entry = 
new BackoffCacheEntry(request, status, backoffPolicy); + logger.log( + ChannelLogLevel.DEBUG, + "[RLS Entry {0}] Transition to back off: status={1}, delayNanos={2}", + routeLookupRequestKey, status, delayNanos); + BackoffCacheEntry entry = new BackoffCacheEntry(routeLookupRequestKey, status, backoffPolicy, + ticker.read() + delayNanos * 2); // Lock is held, so the task can't execute before the assignment entry.scheduledFuture = scheduledExecutorService.schedule( () -> refreshBackoffEntry(entry), delayNanos, TimeUnit.NANOSECONDS); - linkedHashLruCache.cacheAndClean(request, entry); - logger.log(ChannelLogLevel.DEBUG, "BackoffCacheEntry created with a delay of {0} nanos", - delayNanos); + linkedHashLruCache.cacheAndClean(routeLookupRequestKey, entry); return entry; } @@ -443,9 +496,8 @@ private void refreshBackoffEntry(BackoffCacheEntry entry) { // Future was previously cancelled return; } - logger.log(ChannelLogLevel.DEBUG, "Calling RLS for transition to pending"); - linkedHashLruCache.invalidate(entry.request); - asyncRlsCall(entry.request, entry.backoffPolicy); + // Cache updated. updateBalancingState() to reattempt picks + helper.triggerPendingRpcProcessing(); } } @@ -578,15 +630,15 @@ public String toString() { /** A pending cache entry when the async RouteLookup RPC is still on the fly. 
*/ static final class PendingCacheEntry { private final ListenableFuture pendingCall; - private final RouteLookupRequest request; + private final RouteLookupRequestKey routeLookupRequestKey; @Nullable private final BackoffPolicy backoffPolicy; PendingCacheEntry( - RouteLookupRequest request, + RouteLookupRequestKey routeLookupRequestKey, ListenableFuture pendingCall, @Nullable BackoffPolicy backoffPolicy) { - this.request = checkNotNull(request, "request"); + this.routeLookupRequestKey = checkNotNull(routeLookupRequestKey, "request"); this.pendingCall = checkNotNull(pendingCall, "pendingCall"); this.backoffPolicy = backoffPolicy; } @@ -594,7 +646,7 @@ static final class PendingCacheEntry { @Override public String toString() { return MoreObjects.toStringHelper(this) - .add("request", request) + .add("routeLookupRequestKey", routeLookupRequestKey) .toString(); } } @@ -602,10 +654,10 @@ public String toString() { /** Common cache entry data for {@link RlsAsyncLruCache}. */ abstract static class CacheEntry { - protected final RouteLookupRequest request; + protected final RouteLookupRequestKey routeLookupRequestKey; - CacheEntry(RouteLookupRequest request) { - this.request = checkNotNull(request, "request"); + CacheEntry(RouteLookupRequestKey routeLookupRequestKey) { + this.routeLookupRequestKey = checkNotNull(routeLookupRequestKey, "request"); } abstract int getSizeBytes(); @@ -628,8 +680,9 @@ final class DataCacheEntry extends CacheEntry { private final List childPolicyWrappers; // GuardedBy CachingRlsLbClient.lock - DataCacheEntry(RouteLookupRequest request, final RouteLookupResponse response) { - super(request); + DataCacheEntry(RouteLookupRequestKey routeLookupRequestKey, + final RouteLookupResponse response) { + super(routeLookupRequestKey); this.response = checkNotNull(response, "response"); checkState(!response.targets().isEmpty(), "No targets returned by RLS"); childPolicyWrappers = @@ -657,13 +710,14 @@ final class DataCacheEntry extends CacheEntry { */ void 
maybeRefresh() { synchronized (lock) { // Lock is already held, but ErrorProne can't tell - if (pendingCallCache.containsKey(request)) { + if (pendingCallCache.containsKey(routeLookupRequestKey)) { // pending already requested - logger.log(ChannelLogLevel.DEBUG, - "A pending refresh request already created, no need to proceed with refresh"); return; } - asyncRlsCall(request, /* backoffPolicy= */ null); + logger.log(ChannelLogLevel.DEBUG, + "[RLS Entry {0}] Cache entry is stale, refreshing", routeLookupRequestKey); + asyncRlsCall(routeLookupRequestKey, /* backoffPolicy= */ null, + RouteLookupRequest.Reason.REASON_STALE); } } @@ -733,7 +787,7 @@ void cleanup() { @Override public String toString() { return MoreObjects.toStringHelper(this) - .add("request", request) + .add("request", routeLookupRequestKey) .add("response", response) .add("expireTime", expireTime) .add("staleTime", staleTime) @@ -750,12 +804,15 @@ private static final class BackoffCacheEntry extends CacheEntry { private final Status status; private final BackoffPolicy backoffPolicy; + private final long expiryTimeNanos; private Future scheduledFuture; - BackoffCacheEntry(RouteLookupRequest request, Status status, BackoffPolicy backoffPolicy) { - super(request); + BackoffCacheEntry(RouteLookupRequestKey routeLookupRequestKey, Status status, + BackoffPolicy backoffPolicy, long expiryTimeNanos) { + super(routeLookupRequestKey); this.status = checkNotNull(status, "status"); this.backoffPolicy = checkNotNull(backoffPolicy, "backoffPolicy"); + this.expiryTimeNanos = expiryTimeNanos; } Status getStatus() { @@ -767,9 +824,13 @@ int getSizeBytes() { return OBJ_OVERHEAD_B * 3 + Long.SIZE + 8; // 3 java objects, 1 long and a boolean } + boolean isInBackoffPeriod() { + return !scheduledFuture.isDone(); + } + @Override - boolean isExpired(long now) { - return scheduledFuture.isDone(); + boolean isExpired(long nowNanos) { + return nowNanos > expiryTimeNanos; } @Override @@ -780,7 +841,7 @@ void cleanup() { @Override 
public String toString() { return MoreObjects.toStringHelper(this) - .add("request", request) + .add("request", routeLookupRequestKey) .add("status", status) .toString(); } @@ -799,7 +860,7 @@ static final class Builder { private Throttler throttler = new HappyThrottler(); private ResolvedAddressFactory resolvedAddressFactory; private Ticker ticker = Ticker.systemTicker(); - private EvictionListener evictionListener; + private EvictionListener evictionListener; private BackoffPolicy.Provider backoffProvider = new ExponentialBackoffPolicy.Provider(); Builder setHelper(Helper helper) { @@ -833,7 +894,7 @@ Builder setTicker(Ticker ticker) { } Builder setEvictionListener( - @Nullable EvictionListener evictionListener) { + @Nullable EvictionListener evictionListener) { this.evictionListener = evictionListener; return this; } @@ -855,17 +916,17 @@ CachingRlsLbClient build() { * CacheEntry#cleanup()} after original {@link EvictionListener} is finished. */ private static final class AutoCleaningEvictionListener - implements EvictionListener { + implements EvictionListener { - private final EvictionListener delegate; + private final EvictionListener delegate; AutoCleaningEvictionListener( - @Nullable EvictionListener delegate) { + @Nullable EvictionListener delegate) { this.delegate = delegate; } @Override - public void onEviction(RouteLookupRequest key, CacheEntry value, EvictionType cause) { + public void onEviction(RouteLookupRequestKey key, CacheEntry value, EvictionType cause) { if (delegate != null) { delegate.onEviction(key, value, cause); } @@ -890,29 +951,29 @@ public void registerBackendResponse(boolean throttled) { /** Implementation of {@link LinkedHashLruCache} for RLS. 
*/ private static final class RlsAsyncLruCache - extends LinkedHashLruCache { + extends LinkedHashLruCache { private final RlsLbHelper helper; RlsAsyncLruCache(long maxEstimatedSizeBytes, - @Nullable EvictionListener evictionListener, + @Nullable EvictionListener evictionListener, Ticker ticker, RlsLbHelper helper) { super(maxEstimatedSizeBytes, evictionListener, ticker); this.helper = checkNotNull(helper, "helper"); } @Override - protected boolean isExpired(RouteLookupRequest key, CacheEntry value, long nowNanos) { + protected boolean isExpired(RouteLookupRequestKey key, CacheEntry value, long nowNanos) { return value.isExpired(nowNanos); } @Override - protected int estimateSizeOf(RouteLookupRequest key, CacheEntry value) { + protected int estimateSizeOf(RouteLookupRequestKey key, CacheEntry value) { return value.getSizeBytes(); } @Override protected boolean shouldInvalidateEldestEntry( - RouteLookupRequest eldestKey, CacheEntry eldestValue, long now) { + RouteLookupRequestKey eldestKey, CacheEntry eldestValue, long now) { if (!eldestValue.isOldEnoughToBeEvicted(now)) { return false; } @@ -921,7 +982,7 @@ protected boolean shouldInvalidateEldestEntry( return this.estimatedSizeBytes() > this.estimatedMaxSizeBytes(); } - public CacheEntry cacheAndClean(RouteLookupRequest key, CacheEntry value) { + public CacheEntry cacheAndClean(RouteLookupRequestKey key, CacheEntry value) { CacheEntry newEntry = cache(key, value); // force cleanup if new entry pushed cache over max size (in bytes) @@ -932,35 +993,6 @@ public CacheEntry cacheAndClean(RouteLookupRequest key, CacheEntry value) { } } - /** - * LbStatusListener refreshes {@link BackoffCacheEntry} when lb state is changed to {@link - * ConnectivityState#READY} from {@link ConnectivityState#TRANSIENT_FAILURE}. 
- */ - private final class BackoffRefreshListener implements ChildLbStatusListener { - - @Nullable - private ConnectivityState prevState = null; - - @Override - public void onStatusChanged(ConnectivityState newState) { - logger.log(ChannelLogLevel.DEBUG, "LB status changed to: {0}", newState); - if (prevState == ConnectivityState.TRANSIENT_FAILURE - && newState == ConnectivityState.READY) { - logger.log(ChannelLogLevel.DEBUG, "Transitioning from TRANSIENT_FAILURE to READY"); - logger.log(ChannelLogLevel.DEBUG, "Acquiring lock force refresh backoff cache entries"); - synchronized (lock) { - logger.log(ChannelLogLevel.DEBUG, "Lock acquired for refreshing backoff cache entries"); - for (CacheEntry value : linkedHashLruCache.values()) { - if (value instanceof BackoffCacheEntry) { - refreshBackoffEntry((BackoffCacheEntry) value); - } - } - } - } - prevState = newState; - } - } - /** A header will be added when RLS server respond with additional header data. */ @VisibleForTesting static final Metadata.Key RLS_DATA_KEY = @@ -980,67 +1012,50 @@ final class RlsPicker extends SubchannelPicker { public PickResult pickSubchannel(PickSubchannelArgs args) { String serviceName = args.getMethodDescriptor().getServiceName(); String methodName = args.getMethodDescriptor().getBareMethodName(); - RouteLookupRequest request = + RlsProtoData.RouteLookupRequestKey lookupRequestKey = requestFactory.create(serviceName, methodName, args.getHeaders()); - final CachedRouteLookupResponse response = CachingRlsLbClient.this.get(request); - logger.log(ChannelLogLevel.DEBUG, - "Got route lookup cache entry for service={0}, method={1}, headers={2}:\n {3}", - new Object[]{serviceName, methodName, args.getHeaders(), response}); + final CachedRouteLookupResponse response = CachingRlsLbClient.this.get(lookupRequestKey); if (response.getHeaderData() != null && !response.getHeaderData().isEmpty()) { - logger.log(ChannelLogLevel.DEBUG, "Updating RLS metadata from the RLS response headers"); Metadata 
headers = args.getHeaders(); headers.discardAll(RLS_DATA_KEY); headers.put(RLS_DATA_KEY, response.getHeaderData()); } String defaultTarget = lbPolicyConfig.getRouteLookupConfig().defaultTarget(); - logger.log(ChannelLogLevel.DEBUG, "defaultTarget = {0}", defaultTarget); boolean hasFallback = defaultTarget != null && !defaultTarget.isEmpty(); if (response.hasData()) { - logger.log(ChannelLogLevel.DEBUG, "RLS response has data, proceed with selecting a picker"); ChildPolicyWrapper childPolicyWrapper = response.getChildPolicyWrapper(); SubchannelPicker picker = (childPolicyWrapper != null) ? childPolicyWrapper.getPicker() : null; if (picker == null) { - logger.log(ChannelLogLevel.DEBUG, - "Child policy wrapper didn't return a picker, returning PickResult with no results"); return PickResult.withNoResult(); } // Happy path - logger.log(ChannelLogLevel.DEBUG, "Returning PickResult"); PickResult pickResult = picker.pickSubchannel(args); if (pickResult.hasResult()) { helper.getMetricRecorder().addLongCounter(TARGET_PICKS_COUNTER, 1, Arrays.asList(helper.getChannelTarget(), lookupService, childPolicyWrapper.getTarget(), determineMetricsPickResult(pickResult)), - Collections.emptyList()); + Arrays.asList(determineCustomLabel(args))); } return pickResult; } else if (response.hasError()) { - logger.log(ChannelLogLevel.DEBUG, "RLS response has errors"); if (hasFallback) { - logger.log(ChannelLogLevel.DEBUG, "Using RLS fallback"); return useFallback(args); } - logger.log(ChannelLogLevel.DEBUG, "No RLS fallback, returning PickResult with an error"); helper.getMetricRecorder().addLongCounter(FAILED_PICKS_COUNTER, 1, - Arrays.asList(helper.getChannelTarget(), lookupService), Collections.emptyList()); + Arrays.asList(helper.getChannelTarget(), lookupService), + Arrays.asList(determineCustomLabel(args))); return PickResult.withError( convertRlsServerStatus(response.getStatus(), lbPolicyConfig.getRouteLookupConfig().lookupService())); } else { - logger.log(ChannelLogLevel.DEBUG, - 
"RLS response had no data, return a PickResult with no data"); return PickResult.withNoResult(); } } - private ChildPolicyWrapper fallbackChildPolicyWrapper; - /** Uses Subchannel connected to default target. */ private PickResult useFallback(PickSubchannelArgs args) { - // TODO(creamsoup) wait until lb is ready - startFallbackChildPolicy(); SubchannelPicker picker = fallbackChildPolicyWrapper.getPicker(); if (picker == null) { return PickResult.withNoResult(); @@ -1050,7 +1065,7 @@ private PickResult useFallback(PickSubchannelArgs args) { helper.getMetricRecorder().addLongCounter(DEFAULT_TARGET_PICKS_COUNTER, 1, Arrays.asList(helper.getChannelTarget(), lookupService, fallbackChildPolicyWrapper.getTarget(), determineMetricsPickResult(pickResult)), - Collections.emptyList()); + Arrays.asList(determineCustomLabel(args))); } return pickResult; } @@ -1065,23 +1080,13 @@ private String determineMetricsPickResult(PickResult pickResult) { } } - private void startFallbackChildPolicy() { - String defaultTarget = lbPolicyConfig.getRouteLookupConfig().defaultTarget(); - logger.log(ChannelLogLevel.DEBUG, "starting fallback to {0}", defaultTarget); - logger.log(ChannelLogLevel.DEBUG, "Acquiring lock to start fallback child policy"); - synchronized (lock) { - logger.log(ChannelLogLevel.DEBUG, "Acquired lock for starting fallback child policy"); - if (fallbackChildPolicyWrapper != null) { - return; - } - fallbackChildPolicyWrapper = refCountedChildPolicyWrapperFactory.createOrGet(defaultTarget); - } + private String determineCustomLabel(PickSubchannelArgs args) { + return args.getCallOptions().getOption(Grpc.CALL_OPTION_CUSTOM_LABEL); } // GuardedBy CachingRlsLbClient.lock void close() { synchronized (lock) { // Lock is already held, but ErrorProne can't tell - logger.log(ChannelLogLevel.DEBUG, "Closing RLS picker"); if (fallbackChildPolicyWrapper != null) { refCountedChildPolicyWrapperFactory.release(fallbackChildPolicyWrapper); } diff --git 
a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java index 4d6ceed9235..77ed080e654 100644 --- a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java +++ b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java @@ -31,6 +31,7 @@ import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status; import io.grpc.internal.ObjectPool; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; import io.grpc.rls.RlsProtoData.RouteLookupConfig; @@ -209,35 +210,45 @@ static final class RefCountedChildPolicyWrapperFactory { new HashMap<>(); private final ChildLoadBalancerHelperProvider childLbHelperProvider; - private final ChildLbStatusListener childLbStatusListener; private final ChildLoadBalancingPolicy childPolicy; - private final ResolvedAddressFactory childLbResolvedAddressFactory; + private ResolvedAddressFactory childLbResolvedAddressFactory; public RefCountedChildPolicyWrapperFactory( ChildLoadBalancingPolicy childPolicy, ResolvedAddressFactory childLbResolvedAddressFactory, - ChildLoadBalancerHelperProvider childLbHelperProvider, - ChildLbStatusListener childLbStatusListener) { + ChildLoadBalancerHelperProvider childLbHelperProvider) { this.childPolicy = checkNotNull(childPolicy, "childPolicy"); this.childLbResolvedAddressFactory = checkNotNull(childLbResolvedAddressFactory, "childLbResolvedAddressFactory"); this.childLbHelperProvider = checkNotNull(childLbHelperProvider, "childLbHelperProvider"); - this.childLbStatusListener = checkNotNull(childLbStatusListener, "childLbStatusListener"); } void init() { childLbHelperProvider.init(); } + Status acceptResolvedAddressFactory(ResolvedAddressFactory childLbResolvedAddressFactory) { + this.childLbResolvedAddressFactory = childLbResolvedAddressFactory; + Status status = Status.OK; + for (RefCountedChildPolicyWrapper wrapper : childPolicyMap.values()) { + Status 
newStatus = + wrapper.childPolicyWrapper.acceptResolvedAddressFactory(childLbResolvedAddressFactory); + if (!newStatus.isOk()) { + status = newStatus; + } + } + return status; + } + ChildPolicyWrapper createOrGet(String target) { // TODO(creamsoup) check if the target is valid or not RefCountedChildPolicyWrapper pooledChildPolicyWrapper = childPolicyMap.get(target); if (pooledChildPolicyWrapper == null) { ChildPolicyWrapper childPolicyWrapper = new ChildPolicyWrapper( - target, childPolicy, childLbResolvedAddressFactory, childLbHelperProvider, - childLbStatusListener); + target, childPolicy, childLbHelperProvider); pooledChildPolicyWrapper = RefCountedChildPolicyWrapper.of(childPolicyWrapper); childPolicyMap.put(target, pooledChildPolicyWrapper); + childPolicyWrapper.start(childLbResolvedAddressFactory); return pooledChildPolicyWrapper.getObject(); } else { ChildPolicyWrapper childPolicyWrapper = pooledChildPolicyWrapper.getObject(); @@ -277,32 +288,33 @@ static final class ChildPolicyWrapper { private final String target; private final ChildPolicyReportingHelper helper; private final LoadBalancer lb; + private final Object childLbConfig; private volatile SubchannelPicker picker; private ConnectivityState state; public ChildPolicyWrapper( String target, ChildLoadBalancingPolicy childPolicy, - final ResolvedAddressFactory childLbResolvedAddressFactory, - ChildLoadBalancerHelperProvider childLbHelperProvider, - ChildLbStatusListener childLbStatusListener) { + ChildLoadBalancerHelperProvider childLbHelperProvider) { this.target = target; - this.helper = - new ChildPolicyReportingHelper(childLbHelperProvider, childLbStatusListener); + this.helper = new ChildPolicyReportingHelper(childLbHelperProvider); LoadBalancerProvider lbProvider = childPolicy.getEffectiveLbProvider(); final ConfigOrError lbConfig = lbProvider .parseLoadBalancingPolicyConfig( childPolicy.getEffectiveChildPolicy(target)); this.lb = lbProvider.newLoadBalancer(helper); + this.childLbConfig = 
lbConfig.getConfig(); helper.getChannelLogger().log( - ChannelLogLevel.DEBUG, "RLS child lb created. config: {0}", lbConfig.getConfig()); + ChannelLogLevel.DEBUG, "RLS child lb created. config: {0}", childLbConfig); + } + + void start(ResolvedAddressFactory childLbResolvedAddressFactory) { helper.getSynchronizationContext().execute( new Runnable() { @Override public void run() { - if (!lb.acceptResolvedAddresses( - childLbResolvedAddressFactory.create(lbConfig.getConfig())).isOk()) { + if (!acceptResolvedAddressFactory(childLbResolvedAddressFactory).isOk()) { helper.refreshNameResolution(); } lb.requestConnection(); @@ -310,6 +322,11 @@ public void run() { }); } + Status acceptResolvedAddressFactory(ResolvedAddressFactory childLbResolvedAddressFactory) { + helper.getSynchronizationContext().throwIfNotInThisSynchronizationContext(); + return lb.acceptResolvedAddresses(childLbResolvedAddressFactory.create(childLbConfig)); + } + String getTarget() { return target; } @@ -366,14 +383,11 @@ public String toString() { final class ChildPolicyReportingHelper extends ForwardingLoadBalancerHelper { private final ChildLoadBalancerHelper delegate; - private final ChildLbStatusListener listener; ChildPolicyReportingHelper( - ChildLoadBalancerHelperProvider childHelperProvider, - ChildLbStatusListener listener) { + ChildLoadBalancerHelperProvider childHelperProvider) { checkNotNull(childHelperProvider, "childHelperProvider"); this.delegate = childHelperProvider.forTarget(getTarget()); - this.listener = checkNotNull(listener, "listener"); } @Override @@ -386,18 +400,10 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne picker = newPicker; state = newState; super.updateBalancingState(newState, newPicker); - listener.onStatusChanged(newState); } } } - /** Listener for child lb status change events. */ - interface ChildLbStatusListener { - - /** Notifies when child lb status changes. 
*/ - void onStatusChanged(ConnectivityState newState); - } - private static final class RefCountedChildPolicyWrapper implements ObjectPool { diff --git a/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java b/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java index ba0575efa57..9a961759693 100644 --- a/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java +++ b/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java @@ -22,6 +22,7 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Ticker; +import com.google.errorprone.annotations.CheckReturnValue; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; @@ -29,7 +30,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; /** @@ -43,7 +43,8 @@ abstract class LinkedHashLruCache implements LruCache { private final LinkedHashMap delegate; private final Ticker ticker; - private final EvictionListener evictionListener; + @Nullable + private final EvictionListener evictionListener; private long estimatedSizeBytes; private long estimatedMaxSizeBytes; @@ -53,7 +54,7 @@ abstract class LinkedHashLruCache implements LruCache { final Ticker ticker) { checkState(estimatedMaxSizeBytes > 0, "max estimated cache size should be positive"); this.estimatedMaxSizeBytes = estimatedMaxSizeBytes; - this.evictionListener = new SizeHandlingEvictionListener(evictionListener); + this.evictionListener = evictionListener; this.ticker = checkNotNull(ticker, "ticker"); delegate = new LinkedHashMap( // rough estimate or minimum hashmap default @@ -135,7 +136,7 @@ public final V cache(K key, V value) { estimatedSizeBytes += size; existing = delegate.put(key, new SizedValue(size, value)); if (existing != null) { - evictionListener.onEviction(key, existing, EvictionType.REPLACED); + fireOnEviction(key, existing, EvictionType.REPLACED); } return existing == null ? 
null : existing.value; } @@ -174,7 +175,7 @@ private V invalidate(K key, EvictionType cause) { checkNotNull(cause, "cause"); SizedValue existing = delegate.remove(key); if (existing != null) { - evictionListener.onEviction(key, existing, cause); + fireOnEviction(key, existing, cause); } return existing == null ? null : existing.value; } @@ -185,7 +186,7 @@ public final void invalidateAll() { while (iterator.hasNext()) { Map.Entry entry = iterator.next(); if (entry.getValue() != null) { - evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.EXPLICIT); + fireOnEviction(entry.getKey(), entry.getValue(), EvictionType.EXPLICIT); } iterator.remove(); } @@ -215,14 +216,13 @@ public final List values() { protected final boolean fitToLimit() { boolean removedAnyUnexpired = false; if (estimatedSizeBytes <= estimatedMaxSizeBytes) { - // new size is larger no need to do cleanup return false; } // cleanup expired entries long now = ticker.read(); cleanupExpiredEntries(now); - // cleanup eldest entry until new size limit + // cleanup eldest entry until the size of all entries fits within the limit Iterator> lruIter = delegate.entrySet().iterator(); while (lruIter.hasNext() && estimatedMaxSizeBytes < this.estimatedSizeBytes) { Map.Entry entry = lruIter.next(); @@ -230,8 +230,8 @@ protected final boolean fitToLimit() { break; // Violates some constraint like minimum age so stop our cleanup } lruIter.remove(); - // eviction listener will update the estimatedSizeBytes - evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.SIZE); + // fireOnEviction will update the estimatedSizeBytes + fireOnEviction(entry.getKey(), entry.getValue(), EvictionType.SIZE); removedAnyUnexpired = true; } return removedAnyUnexpired; @@ -270,7 +270,7 @@ private boolean cleanupExpiredEntries(int maxExpiredEntries, long now) { Map.Entry entry = lruIter.next(); if (isExpired(entry.getKey(), entry.getValue().value, now)) { lruIter.remove(); - 
evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.EXPIRED); + fireOnEviction(entry.getKey(), entry.getValue(), EvictionType.EXPIRED); removedAny = true; maxExpiredEntries--; } @@ -283,21 +283,10 @@ public final void close() { invalidateAll(); } - /** A {@link EvictionListener} keeps track of size. */ - private final class SizeHandlingEvictionListener implements EvictionListener { - - private final EvictionListener delegate; - - SizeHandlingEvictionListener(@Nullable EvictionListener delegate) { - this.delegate = delegate; - } - - @Override - public void onEviction(K key, SizedValue value, EvictionType cause) { - estimatedSizeBytes -= value.size; - if (delegate != null) { - delegate.onEviction(key, value.value, cause); - } + private void fireOnEviction(K key, SizedValue value, EvictionType cause) { + estimatedSizeBytes -= value.size; + if (evictionListener != null) { + evictionListener.onEviction(key, value.value, cause); } } diff --git a/rls/src/main/java/io/grpc/rls/LruCache.java b/rls/src/main/java/io/grpc/rls/LruCache.java index 1ad5a958289..8fc4ae98472 100644 --- a/rls/src/main/java/io/grpc/rls/LruCache.java +++ b/rls/src/main/java/io/grpc/rls/LruCache.java @@ -16,7 +16,7 @@ package io.grpc.rls; -import javax.annotation.CheckReturnValue; +import com.google.errorprone.annotations.CheckReturnValue; import javax.annotation.Nullable; /** An LruCache is a cache with least recently used eviction. 
*/ diff --git a/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java b/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java index d1e537f1482..848199f50a8 100644 --- a/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java +++ b/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java @@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.MoreObjects; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityState; @@ -50,12 +49,11 @@ final class RlsLoadBalancer extends LoadBalancer { @Override public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - logger.log(ChannelLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); LbPolicyConfiguration lbPolicyConfiguration = (LbPolicyConfiguration) resolvedAddresses.getLoadBalancingPolicyConfig(); checkNotNull(lbPolicyConfiguration, "Missing RLS LB config"); if (!lbPolicyConfiguration.equals(this.lbPolicyConfiguration)) { - logger.log(ChannelLogLevel.DEBUG, "A new RLS LB config received"); + logger.log(ChannelLogLevel.DEBUG, "A new RLS LB config received: {0}", lbPolicyConfiguration); boolean needToConnect = this.lbPolicyConfiguration == null || !this.lbPolicyConfiguration.getRouteLookupConfig().lookupService().equals( lbPolicyConfiguration.getRouteLookupConfig().lookupService()); @@ -80,50 +78,32 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { // not required. 
this.lbPolicyConfiguration = lbPolicyConfiguration; } - logger.log(ChannelLogLevel.DEBUG, "RLS LB accepted resolved addresses successfully"); - return Status.OK; + return routeLookupClient.acceptResolvedAddressFactory( + new ChildLbResolvedAddressFactory( + resolvedAddresses.getAddresses(), resolvedAddresses.getAttributes())); } @Override public void requestConnection() { - logger.log(ChannelLogLevel.DEBUG, "connection requested from RLS LB"); if (routeLookupClient != null) { - logger.log(ChannelLogLevel.DEBUG, "requesting a connection from the routeLookupClient"); routeLookupClient.requestConnection(); } } @Override public void handleNameResolutionError(final Status error) { - logger.log(ChannelLogLevel.DEBUG, "Received resolution error: {0}", error); - class ErrorPicker extends SubchannelPicker { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withError(error); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("error", error) - .toString(); - } - } - if (routeLookupClient != null) { logger.log(ChannelLogLevel.DEBUG, "closing the routeLookupClient on a name resolution error"); routeLookupClient.close(); routeLookupClient = null; lbPolicyConfiguration = null; } - logger.log(ChannelLogLevel.DEBUG, - "Updating balancing state to TRANSIENT_FAILURE with an error picker"); - helper.updateBalancingState(ConnectivityState.TRANSIENT_FAILURE, new ErrorPicker()); + helper.updateBalancingState( + ConnectivityState.TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); } @Override public void shutdown() { - logger.log(ChannelLogLevel.DEBUG, "Rls lb shutdown"); if (routeLookupClient != null) { logger.log(ChannelLogLevel.DEBUG, "closing the routeLookupClient because of RLS LB shutdown"); routeLookupClient.close(); diff --git a/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java b/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java index cd164f5e2a7..70f9fb4d891 100644 
--- a/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java +++ b/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java @@ -64,7 +64,9 @@ static final class RouteLookupRequestConverter @Override protected RlsProtoData.RouteLookupRequest doForward(RouteLookupRequest routeLookupRequest) { return RlsProtoData.RouteLookupRequest.create( - ImmutableMap.copyOf(routeLookupRequest.getKeyMapMap())); + ImmutableMap.copyOf(routeLookupRequest.getKeyMapMap()), + RlsProtoData.RouteLookupRequest.Reason.valueOf(routeLookupRequest.getReason().name()) + ); } @Override @@ -72,6 +74,7 @@ protected RouteLookupRequest doBackward(RlsProtoData.RouteLookupRequest routeLoo return RouteLookupRequest.newBuilder() .setTargetType("grpc") + .setReason(RouteLookupRequest.Reason.valueOf(routeLookupRequest.reason().name())) .putAllKeyMap(routeLookupRequest.keyMap()) .build(); } @@ -152,10 +155,15 @@ protected RouteLookupConfig doForward(Map json) { checkArgument(staleAge == null, "to specify staleAge, must have maxAge"); maxAge = MAX_AGE_NANOS; } - if (staleAge == null) { + // If staleAge is not set, clamp maxAge to <= 5. + if (staleAge == null && maxAge > MAX_AGE_NANOS) { + maxAge = MAX_AGE_NANOS; + } + // Clamp staleAge to <= 5 + if (staleAge == null || staleAge > MAX_AGE_NANOS) { staleAge = MAX_AGE_NANOS; } - maxAge = Math.min(maxAge, MAX_AGE_NANOS); + // Ignore staleAge if greater than maxAge. staleAge = Math.min(staleAge, maxAge); long cacheSize = orDefault(JsonUtil.getNumberAsLong(json, "cacheSizeBytes"), MAX_CACHE_SIZE); checkArgument(cacheSize > 0, "cacheSize must be positive"); diff --git a/rls/src/main/java/io/grpc/rls/RlsProtoData.java b/rls/src/main/java/io/grpc/rls/RlsProtoData.java index 49f32c6b6e3..39c404870f9 100644 --- a/rls/src/main/java/io/grpc/rls/RlsProtoData.java +++ b/rls/src/main/java/io/grpc/rls/RlsProtoData.java @@ -27,16 +27,42 @@ final class RlsProtoData { private RlsProtoData() {} + /** A key object for the Rls route lookup data cache. 
*/ + @AutoValue + @Immutable + abstract static class RouteLookupRequestKey { + + /** Returns a map of key values extracted via key builders for the gRPC or HTTP request. */ + abstract ImmutableMap keyMap(); + + static RouteLookupRequestKey create(ImmutableMap keyMap) { + return new AutoValue_RlsProtoData_RouteLookupRequestKey(keyMap); + } + } + /** A request object sent to route lookup service. */ @AutoValue @Immutable abstract static class RouteLookupRequest { + /** Names should match those in {@link io.grpc.lookup.v1.RouteLookupRequest.Reason}. */ + enum Reason { + /** Unused. */ + REASON_UNKNOWN, + /** No data available in local cache. */ + REASON_MISS, + /** Data in local cache is stale. */ + REASON_STALE; + } + + /** Reason for making this request. */ + abstract Reason reason(); + /** Returns a map of key values extracted via key builders for the gRPC or HTTP request. */ abstract ImmutableMap keyMap(); - static RouteLookupRequest create(ImmutableMap keyMap) { - return new AutoValue_RlsProtoData_RouteLookupRequest(keyMap); + static RouteLookupRequest create(ImmutableMap keyMap, Reason reason) { + return new AutoValue_RlsProtoData_RouteLookupRequest(reason, keyMap); } } diff --git a/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java b/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java index a6ca0137ff1..1fed78f4df3 100644 --- a/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java +++ b/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java @@ -20,20 +20,20 @@ import com.google.common.base.MoreObjects; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Metadata; import io.grpc.rls.RlsProtoData.ExtraKeys; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; import io.grpc.rls.RlsProtoData.RouteLookupConfig; -import io.grpc.rls.RlsProtoData.RouteLookupRequest; +import 
io.grpc.rls.RlsProtoData.RouteLookupRequestKey; import java.util.HashMap; import java.util.List; import java.util.Map; -import javax.annotation.CheckReturnValue; /** - * A RlsRequestFactory creates {@link RouteLookupRequest} using key builder map from {@link + * A RlsRequestFactory creates {@link RouteLookupRequestKey} using key builder map from {@link * RouteLookupConfig}. */ final class RlsRequestFactory { @@ -61,9 +61,9 @@ private static Map createKeyBuilderTable( return table; } - /** Creates a {@link RouteLookupRequest} for given request's metadata. */ + /** Creates a {@link RouteLookupRequestKey} for the given request lookup metadata. */ @CheckReturnValue - RouteLookupRequest create(String service, String method, Metadata metadata) { + RouteLookupRequestKey create(String service, String method, Metadata metadata) { checkNotNull(service, "service"); checkNotNull(method, "method"); String path = "/" + service + "/" + method; @@ -73,7 +73,7 @@ RouteLookupRequest create(String service, String method, Metadata metadata) { grpcKeyBuilder = keyBuilderTable.get("/" + service + "/*"); } if (grpcKeyBuilder == null) { - return RouteLookupRequest.create(ImmutableMap.of()); + return RouteLookupRequestKey.create(ImmutableMap.of()); } ImmutableMap.Builder rlsRequestHeaders = createRequestHeaders(metadata, grpcKeyBuilder.headers()); @@ -89,7 +89,7 @@ RouteLookupRequest create(String service, String method, Metadata metadata) { rlsRequestHeaders.put(extraKeys.method(), method); } rlsRequestHeaders.putAll(constantKeys); - return RouteLookupRequest.create(rlsRequestHeaders.buildOrThrow()); + return RouteLookupRequestKey.create(rlsRequestHeaders.buildOrThrow()); } private ImmutableMap.Builder createRequestHeaders( diff --git a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java index 7c5df2c96b3..b349aecdbf3 100644 --- a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java +++ 
b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java @@ -59,6 +59,7 @@ import io.grpc.MetricRecorder.BatchRecorder; import io.grpc.MetricRecorder.Registration; import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Server; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; @@ -66,7 +67,10 @@ import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.BackoffPolicy; import io.grpc.internal.FakeClock; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ObjectPool; import io.grpc.internal.PickSubchannelArgsImpl; +import io.grpc.internal.SharedResourcePool; import io.grpc.lookup.v1.RouteLookupServiceGrpc; import io.grpc.rls.CachingRlsLbClient.CacheEntry; import io.grpc.rls.CachingRlsLbClient.CachedRouteLookupResponse; @@ -96,10 +100,13 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.Nonnull; import org.junit.After; import org.junit.Before; @@ -128,7 +135,7 @@ public class CachingRlsLbClientTest { public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); @Mock - private EvictionListener evictionListener; + private EvictionListener evictionListener; @Mock private SocketAddress socketAddress; @Mock @@ -160,8 +167,9 @@ public void uncaughtException(Thread t, Throwable e) { fakeClock.getScheduledExecutorService()); private final ChildLoadBalancingPolicy childLbPolicy = new ChildLoadBalancingPolicy("target", Collections.emptyMap(), lbProvider); + private final FakeHelper fakeHelper = new FakeHelper(); private final Helper helper = - mock(Helper.class, delegatesTo(new FakeHelper())); + mock(Helper.class, 
delegatesTo(fakeHelper)); private final FakeThrottler fakeThrottler = new FakeThrottler(); private final LbPolicyConfiguration lbPolicyConfiguration = new LbPolicyConfiguration(ROUTE_LOOKUP_CONFIG, null, childLbPolicy); @@ -191,21 +199,23 @@ public void setUpMockMetricRecorder() { @After public void tearDown() throws Exception { - rlsLbClient.close(); + if (rlsLbClient != null) { + rlsLbClient.close(); + } assertWithMessage( "On client shut down, RlsLoadBalancer must shut down with all its child loadbalancers.") .that(lbProvider.loadBalancers).isEmpty(); } private CachedRouteLookupResponse getInSyncContext( - final RouteLookupRequest request) + final RlsProtoData.RouteLookupRequestKey routeLookupRequestKey) throws ExecutionException, InterruptedException, TimeoutException { final SettableFuture responseSettableFuture = SettableFuture.create(); syncContext.execute(new Runnable() { @Override public void run() { - responseSettableFuture.set(rlsLbClient.get(request)); + responseSettableFuture.set(rlsLbClient.get(routeLookupRequestKey)); } }); return responseSettableFuture.get(5, TimeUnit.SECONDS); @@ -215,48 +225,53 @@ public void run() { public void get_noError_lifeCycle() throws Exception { setUpRlsLbClient(); InOrder inOrder = inOrder(evictionListener); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create(ImmutableList.of("target"), "header"))); // initial request - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); 
assertThat(resp.isPending()).isTrue(); // server response fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); // cache hit for staled entry fakeClock.forwardTime(ROUTE_LOOKUP_CONFIG.staleAgeInNanos(), TimeUnit.NANOSECONDS); - resp = getInSyncContext(routeLookupRequest); + rlsServerImpl.routeLookupReason = null; + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); // async refresh finishes fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); inOrder .verify(evictionListener) - .onEviction(eq(routeLookupRequest), any(CacheEntry.class), eq(EvictionType.REPLACED)); + .onEviction(eq(routeLookupRequestKey), any(CacheEntry.class), eq(EvictionType.REPLACED)); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); + assertThat(rlsServerImpl.routeLookupReason).isEqualTo( + io.grpc.lookup.v1.RouteLookupRequest.Reason.REASON_STALE); assertThat(resp.hasData()).isTrue(); // existing cache expired fakeClock.forwardTime(ROUTE_LOOKUP_CONFIG.maxAgeInNanos(), TimeUnit.NANOSECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.isPending()).isTrue(); inOrder .verify(evictionListener) - .onEviction(eq(routeLookupRequest), any(CacheEntry.class), eq(EvictionType.EXPIRED)); + .onEviction(eq(routeLookupRequestKey), any(CacheEntry.class), eq(EvictionType.EXPIRED)); inOrder.verifyNoMoreInteractions(); } @@ -285,99 +300,275 @@ public void rls_withCustomRlsChannelServiceConfig() throws Exception { .setThrottler(fakeThrottler) .setTicker(fakeClock.getTicker()) .build(); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + RlsProtoData.RouteLookupRequestKey 
routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create(ImmutableList.of("target"), "header"))); + rlsServerImpl.routeLookupReason = null; // initial request - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.isPending()).isTrue(); // server response fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); + assertThat(rlsServerImpl.routeLookupReason).isEqualTo( + io.grpc.lookup.v1.RouteLookupRequest.Reason.REASON_MISS); assertThat(rlsChannelOverriddenAuthority).isEqualTo("bigtable.googleapis.com:443"); assertThat(rlsChannelServiceConfig).isEqualTo(routeLookupChannelServiceConfig); } @Test - public void get_throttledAndRecover() throws Exception { + public void backoffTimerEnd_updatesPicker() throws Exception { setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + InOrder inOrder = inOrder(helper); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create(ImmutableList.of("target"), "header"))); fakeThrottler.nextResult = true; fakeBackoffProvider.nextPolicy = createBackoffPolicy(10, TimeUnit.MILLISECONDS); - CachedRouteLookupResponse resp = 
getInSyncContext(routeLookupRequest); - + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasError()).isTrue(); fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); - // initially backed off entry is backed off again - verify(evictionListener) - .onEviction(eq(routeLookupRequest), any(CacheEntry.class), eq(EvictionType.EXPLICIT)); + // Assert that Rls LB policy picker was updated which picks the fallback target + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); + ArgumentCaptor stateCaptor = + ArgumentCaptor.forClass(ConnectivityState.class); - resp = getInSyncContext(routeLookupRequest); + inOrder.verify(helper, times(3)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + assertThat(new HashSet<>(pickerCaptor.getAllValues())).hasSize(1); + assertThat(stateCaptor.getAllValues()) + .containsExactly(ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.CONNECTING, + ConnectivityState.CONNECTING); + Metadata headers = new Metadata(); + PickResult pickResult = getPickResultForCreate(pickerCaptor, headers); + assertThat(pickResult.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(pickResult.getStatus().getDescription()).isEqualTo("fallback not available"); + } + @Test + public void get_throttledTwice_usesSameBackoffpolicy() throws Exception { + setUpRlsLbClient(); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + rlsServerImpl.setLookupTable( + ImmutableMap.of( + routeLookupRequestKey, + RouteLookupResponse.create(ImmutableList.of("target"), "header"))); + + fakeThrottler.nextResult = true; + fakeBackoffProvider.nextPolicy = createBackoffPolicy(10, TimeUnit.MILLISECONDS); + + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); + + assertThat(resp.hasError()).isTrue(); + + 
fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); + + // Assert that the same backoff policy is still in effect for the cache entry. + // The below provider should not get used, so the back off time will still be set to 10ms. + fakeBackoffProvider.nextPolicy = createBackoffPolicy(20, TimeUnit.MILLISECONDS); + // let it be throttled again + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasError()).isTrue(); - // let it pass throttler + fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); + + // Backoff entry's backoff timer has gone off, so next rpc should not be backed off. fakeThrottler.nextResult = false; + resp = getInSyncContext(routeLookupRequestKey); + assertThat(resp.isPending()).isTrue(); + + rlsServerImpl.routeLookupReason = null; + // server responses + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + assertThat(rlsServerImpl.routeLookupReason).isEqualTo( + io.grpc.lookup.v1.RouteLookupRequest.Reason.REASON_MISS); + } + + @Test + public void get_errorResponseTwice_usesSameBackoffPolicy() throws Exception { + setUpRlsLbClient(); + RlsProtoData.RouteLookupRequestKey invalidRouteLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create(ImmutableMap.of()); + CachedRouteLookupResponse resp = getInSyncContext(invalidRouteLookupRequestKey); + assertThat(resp.isPending()).isTrue(); + fakeBackoffProvider.nextPolicy = createBackoffPolicy(10, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + assertThat(rlsServerImpl.routeLookupReason).isEqualTo( + io.grpc.lookup.v1.RouteLookupRequest.Reason.REASON_MISS); + + resp = getInSyncContext(invalidRouteLookupRequestKey); + assertThat(resp.hasError()).isTrue(); + + // Backoff time expiry fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); + resp = getInSyncContext(invalidRouteLookupRequestKey); + assertThat(resp.isPending()).isTrue(); + // Assert that the same backoff policy is still in effect for the cache entry. 
+ // The below provider should not get used, so the back off time will still be set to 10ms. + fakeBackoffProvider.nextPolicy = createBackoffPolicy(20, TimeUnit.MILLISECONDS); + // Gets error again and backed off again + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(invalidRouteLookupRequestKey); + assertThat(resp.hasError()).isTrue(); + // Backoff time expiry + fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); + resp = getInSyncContext(invalidRouteLookupRequestKey); assertThat(resp.isPending()).isTrue(); + rlsServerImpl.routeLookupReason = null; // server responses fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + assertThat(rlsServerImpl.routeLookupReason).isEqualTo( + io.grpc.lookup.v1.RouteLookupRequest.Reason.REASON_MISS); + } - resp = getInSyncContext(routeLookupRequest); + @Test + public void controlPlaneTransientToReady_backOffEntriesRemovedAndPickerUpdated() + throws Exception { + setUpRlsLbClient(); + InOrder inOrder = inOrder(helper); + final ConnectivityState[] rlsChannelState = new ConnectivityState[1]; + Runnable channelStateListener = new Runnable() { + @Override + public void run() { + rlsChannelState[0] = fakeHelper.oobChannel.getState(false); + fakeHelper.oobChannel.notifyWhenStateChanged(rlsChannelState[0], this); + synchronized (this) { + notify(); + } + } + }; + fakeHelper.oobChannel.notifyWhenStateChanged(fakeHelper.oobChannel.getState(false), + channelStateListener); + + fakeHelper.server.shutdown(); + // Channel goes to IDLE state from the shutdown listener handling. 
+ try { + if (!fakeHelper.server.awaitTermination(10, TimeUnit.SECONDS)) { + fakeHelper.server.shutdownNow(); // Forceful shutdown if graceful timeout expires + } + } catch (InterruptedException e) { + fakeHelper.server.shutdownNow(); + } + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + // Rls channel will go to TRANSIENT_FAILURE (connection back-off). + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); + assertThat(resp.isPending()).isTrue(); + assertThat(rlsChannelState[0]).isEqualTo(ConnectivityState.TRANSIENT_FAILURE); + // Throttle the next rpc call. + fakeThrottler.nextResult = true; + fakeBackoffProvider.nextPolicy = createBackoffPolicy(10, TimeUnit.MILLISECONDS); - assertThat(resp.hasData()).isTrue(); + // Cause two cache misses by using new request keys. This will create back-off Rls cache + // entries. RLS control plane state transitioning to READY should reset both back-offs but + // update picker only once. 
+ RlsProtoData.RouteLookupRequestKey routeLookupRequestKey2 = + RlsProtoData.RouteLookupRequestKey.create(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo2", "method-key", "bar")); + resp = getInSyncContext(routeLookupRequestKey2); + assertThat(resp.hasError()).isTrue(); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey3 = + RlsProtoData.RouteLookupRequestKey.create(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo3", "method-key", "bar")); + resp = getInSyncContext(routeLookupRequestKey3); + assertThat(resp.hasError()).isTrue(); + + fakeHelper.createServerAndRegister("service1"); + // Wait for Rls control plane channel back-off expiry and its moving to READY + synchronized (channelStateListener) { + channelStateListener.wait(2000); + } + assertThat(rlsChannelState[0]).isEqualTo(ConnectivityState.READY); + final ObjectPool defaultExecutorPool = + SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR); + AtomicBoolean isSuccess = new AtomicBoolean(false); + ((ExecutorService) defaultExecutorPool.getObject()).submit(() -> { + // Assert that Rls LB policy picker was updated which picks the fallback target + ArgumentCaptor pickerCaptor = + ArgumentCaptor.forClass(SubchannelPicker.class); + ArgumentCaptor stateCaptor = + ArgumentCaptor.forClass(ConnectivityState.class); + + inOrder.verify(helper, times(4)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + assertThat(new HashSet<>(pickerCaptor.getAllValues())).hasSize(1); + assertThat(stateCaptor.getAllValues()) + .containsExactly(ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.CONNECTING, + ConnectivityState.CONNECTING, ConnectivityState.CONNECTING); + Metadata headers = new Metadata(); + PickResult pickResult = getPickResultForCreate(pickerCaptor, headers); + assertThat(pickResult.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(pickResult.getStatus().getDescription()).isEqualTo("fallback not 
available"); + isSuccess.set(true); + }).get(); + assertThat(isSuccess.get()).isTrue(); + + fakeThrottler.nextResult = false; + // Rpcs are not backed off now. + assertThat(getInSyncContext(routeLookupRequestKey2).isPending()).isTrue(); + assertThat(getInSyncContext(routeLookupRequestKey3).isPending()).isTrue(); } @Test public void get_updatesLbState() throws Exception { setUpRlsLbClient(); InOrder inOrder = inOrder(helper); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "service1", "method-key", "create")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "service1", + "method-key", "create")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create( ImmutableList.of("primary.cloudbigtable.googleapis.com"), "header-rls-data-value"))); // valid channel - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); ArgumentCaptor stateCaptor = ArgumentCaptor.forClass(ConnectivityState.class); - inOrder.verify(helper, times(2)) + inOrder.verify(helper, times(3)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); assertThat(new HashSet<>(pickerCaptor.getAllValues())).hasSize(1); + // TRANSIENT_FAILURE is because the test setup pretends fallback is not available. 
assertThat(stateCaptor.getAllValues()) - .containsExactly(ConnectivityState.CONNECTING, ConnectivityState.READY); + .containsExactly(ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.CONNECTING, + ConnectivityState.READY); Metadata headers = new Metadata(); PickResult pickResult = getPickResultForCreate(pickerCaptor, headers); assertThat(pickResult.getStatus().isOk()).isTrue(); @@ -389,13 +580,13 @@ public void get_updatesLbState() throws Exception { // move backoff further back to only test error behavior fakeBackoffProvider.nextPolicy = createBackoffPolicy(100, TimeUnit.MILLISECONDS); // try to get invalid - RouteLookupRequest invalidRouteLookupRequest = - RouteLookupRequest.create(ImmutableMap.of()); - CachedRouteLookupResponse errorResp = getInSyncContext(invalidRouteLookupRequest); + RlsProtoData.RouteLookupRequestKey invalidRouteLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create(ImmutableMap.of()); + CachedRouteLookupResponse errorResp = getInSyncContext(invalidRouteLookupRequestKey); assertThat(errorResp.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - errorResp = getInSyncContext(invalidRouteLookupRequest); + errorResp = getInSyncContext(invalidRouteLookupRequestKey); assertThat(errorResp.hasError()).isTrue(); // Channel is still READY because the subchannel for method /service1/create is still READY. 
@@ -419,27 +610,30 @@ public void get_updatesLbState() throws Exception { @Test public void timeout_not_changing_picked_subchannel() throws Exception { setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "service1", "method-key", "create")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "service1", + "method-key", "create")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create( ImmutableList.of("primary.cloudbigtable.googleapis.com", "target2", "target3"), "header-rls-data-value"))); // valid channel - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isFalse(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); ArgumentCaptor stateCaptor = ArgumentCaptor.forClass(ConnectivityState.class); - verify(helper, times(4)).updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + verify(helper, times(5)).updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); Metadata headers = new Metadata(); PickResult pickResult = getPickResultForCreate(pickerCaptor, headers); @@ -489,27 +683,30 @@ public void get_withAdaptiveThrottler() throws Exception { .setTicker(fakeClock.getTicker()) .build(); InOrder inOrder = inOrder(helper); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "service1", "method-key", "create")); + 
RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "service1", + "method-key", "create")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create( ImmutableList.of("primary.cloudbigtable.googleapis.com"), "header-rls-data-value"))); // valid channel - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); ArgumentCaptor stateCaptor = ArgumentCaptor.forClass(ConnectivityState.class); - inOrder.verify(helper, times(2)) + inOrder.verify(helper, times(3)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); Metadata headers = new Metadata(); @@ -520,13 +717,13 @@ public void get_withAdaptiveThrottler() throws Exception { // move backoff further back to only test error behavior fakeBackoffProvider.nextPolicy = createBackoffPolicy(100, TimeUnit.MILLISECONDS); // try to get invalid - RouteLookupRequest invalidRouteLookupRequest = - RouteLookupRequest.create(ImmutableMap.of()); - CachedRouteLookupResponse errorResp = getInSyncContext(invalidRouteLookupRequest); + RlsProtoData.RouteLookupRequestKey invalidRouteLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create(ImmutableMap.of()); + CachedRouteLookupResponse errorResp = getInSyncContext(invalidRouteLookupRequestKey); assertThat(errorResp.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - errorResp = getInSyncContext(invalidRouteLookupRequest); + errorResp = 
getInSyncContext(invalidRouteLookupRequestKey); assertThat(errorResp.hasError()).isTrue(); // Channel is still READY because the subchannel for method /service1/create is still READY. @@ -556,22 +753,26 @@ private PickSubchannelArgsImpl getInvalidArgs(Metadata headers) { @Test public void get_childPolicyWrapper_reusedForSameTarget() throws Exception { setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); - RouteLookupRequest routeLookupRequest2 = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "baz")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey2 = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "baz")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create(ImmutableList.of("target"), "header"), - routeLookupRequest2, + routeLookupRequestKey2, RouteLookupResponse.create(ImmutableList.of("target"), "header2"))); - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); assertThat(resp.getHeaderData()).isEqualTo("header"); @@ -581,11 +782,11 @@ public void get_childPolicyWrapper_reusedForSameTarget() throws Exception { 
assertThat(childPolicyWrapper.getPicker()).isNotInstanceOf(RlsPicker.class); // request2 has same target, it should reuse childPolicyWrapper - CachedRouteLookupResponse resp2 = getInSyncContext(routeLookupRequest2); + CachedRouteLookupResponse resp2 = getInSyncContext(routeLookupRequestKey2); assertThat(resp2.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp2 = getInSyncContext(routeLookupRequest2); + resp2 = getInSyncContext(routeLookupRequestKey2); assertThat(resp2.hasData()).isTrue(); assertThat(resp2.getHeaderData()).isEqualTo("header2"); assertThat(resp2.getChildPolicyWrapper()).isEqualTo(resp.getChildPolicyWrapper()); @@ -594,20 +795,22 @@ public void get_childPolicyWrapper_reusedForSameTarget() throws Exception { @Test public void get_childPolicyWrapper_multiTarget() throws Exception { setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create( ImmutableList.of("target1", "target2", "target3"), "header"))); - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); List policyWrappers = new ArrayList<>(); @@ -676,14 +879,15 @@ public void metricGauges() throws ExecutionException, InterruptedException, Time 
.recordLongGauge(argThat(new LongGaugeInstrumentArgumentMatcher("grpc.lb.rls.cache_size")), eq(0L), any(), any()); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create( - ImmutableMap.of("server", "bigtable.googleapis.com", "service-key", "foo", "method-key", - "bar")); - rlsServerImpl.setLookupTable(ImmutableMap.of(routeLookupRequest, + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of("server", "bigtable.googleapis.com", "service-key", "foo", "method-key", + "bar")); + rlsServerImpl.setLookupTable(ImmutableMap.of(routeLookupRequestKey, RouteLookupResponse.create(ImmutableList.of("target"), "header"))); // Make a request that will populate the cache with an entry - getInSyncContext(routeLookupRequest); + getInSyncContext(routeLookupRequestKey); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); // Gauge values should reflect the new cache entry. @@ -699,6 +903,7 @@ public void metricGauges() throws ExecutionException, InterruptedException, Time // Shutdown rlsLbClient.close(); + rlsLbClient = null; verify(mockGaugeRegistration).close(); } @@ -820,14 +1025,9 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { @Override public void handleNameResolutionError(final Status error) { - class ErrorPicker extends SubchannelPicker { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withError(error); - } - } - - helper.updateBalancingState(ConnectivityState.TRANSIENT_FAILURE, new ErrorPicker()); + helper.updateBalancingState( + ConnectivityState.TRANSIENT_FAILURE, + new FixedResultPicker(PickResult.withError(error))); } @Override @@ -852,7 +1052,9 @@ private static final class StaticFixedDelayRlsServerImpl private final long responseDelayNano; private final ScheduledExecutorService scheduledExecutorService; - private Map lookupTable = ImmutableMap.of(); + private Map lookupTable = + ImmutableMap.of(); + 
io.grpc.lookup.v1.RouteLookupRequest.Reason routeLookupReason; public StaticFixedDelayRlsServerImpl( long responseDelayNano, ScheduledExecutorService scheduledExecutorService) { @@ -862,7 +1064,8 @@ public StaticFixedDelayRlsServerImpl( checkNotNull(scheduledExecutorService, "scheduledExecutorService"); } - private void setLookupTable(Map lookupTable) { + private void setLookupTable(Map lookupTable) { this.lookupTable = checkNotNull(lookupTable, "lookupTable"); } @@ -874,8 +1077,11 @@ public void routeLookup(final io.grpc.lookup.v1.RouteLookupRequest request, new Runnable() { @Override public void run() { + routeLookupReason = request.getReason(); RouteLookupResponse response = - lookupTable.get(REQUEST_CONVERTER.convert(request)); + lookupTable.get( + RlsProtoData.RouteLookupRequestKey.create( + REQUEST_CONVERTER.convert(request).keyMap())); if (response == null) { responseObserver.onError(new RuntimeException("not found")); } else { @@ -889,16 +1095,23 @@ public void run() { private final class FakeHelper extends Helper { + Server server; + ManagedChannel oobChannel; + + void createServerAndRegister(String target) throws IOException { + server = InProcessServerBuilder.forName(target) + .addService(rlsServerImpl) + .directExecutor() + .build() + .start(); + grpcCleanupRule.register(server); + } + @Override public ManagedChannelBuilder createResolvingOobChannelBuilder( String target, ChannelCredentials creds) { try { - grpcCleanupRule.register( - InProcessServerBuilder.forName(target) - .addService(rlsServerImpl) - .directExecutor() - .build() - .start()); + createServerAndRegister(target); } catch (IOException e) { throw new RuntimeException("cannot create server: " + target, e); } @@ -914,7 +1127,8 @@ protected ManagedChannelBuilder delegate() { @Override public ManagedChannel build() { - return grpcCleanupRule.register(super.build()); + oobChannel = super.build(); + return grpcCleanupRule.register(oobChannel); } @Override @@ -943,7 +1157,6 @@ public 
ManagedChannel createOobChannel(EquivalentAddressGroup eag, String author @Override public void updateBalancingState( @Nonnull ConnectivityState newState, @Nonnull SubchannelPicker newPicker) { - // no-op } @Override diff --git a/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java b/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java index d6025d5bad4..de41d0488fc 100644 --- a/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java +++ b/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java @@ -21,6 +21,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -39,14 +40,13 @@ import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; -import io.grpc.rls.LbPolicyConfiguration.ChildLbStatusListener; import io.grpc.rls.LbPolicyConfiguration.ChildLoadBalancingPolicy; import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper; import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper.ChildPolicyReportingHelper; import io.grpc.rls.LbPolicyConfiguration.InvalidChildPolicyConfigException; import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory; -import java.lang.Thread.UncaughtExceptionHandler; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -61,7 +61,9 @@ public class LbPolicyConfigurationTest { private final LoadBalancer lb = mock(LoadBalancer.class); private final SubchannelStateManager subchannelStateManager = new SubchannelStateManagerImpl(); private final SubchannelPicker picker = mock(SubchannelPicker.class); - private final ChildLbStatusListener childLbStatusListener = mock(ChildLbStatusListener.class); + private final 
SynchronizationContext syncContext = new SynchronizationContext((t, e) -> { + throw new AssertionError(e); + }); private final ResolvedAddressFactory resolvedAddressFactory = new ResolvedAddressFactory() { @Override @@ -78,21 +80,12 @@ public ResolvedAddresses create(Object childLbConfig) { ImmutableMap.of("foo", "bar"), lbProvider), resolvedAddressFactory, - new ChildLoadBalancerHelperProvider(helper, subchannelStateManager, picker), - childLbStatusListener); + new ChildLoadBalancerHelperProvider(helper, subchannelStateManager, picker)); @Before public void setUp() { doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger(); - doReturn( - new SynchronizationContext( - new UncaughtExceptionHandler() { - @Override - public void uncaughtException(Thread t, Throwable e) { - throw new AssertionError(e); - } - })) - .when(helper).getSynchronizationContext(); + doReturn(syncContext).when(helper).getSynchronizationContext(); doReturn(lb).when(lbProvider).newLoadBalancer(any(Helper.class)); doReturn(ConfigOrError.fromConfig(new Object())) .when(lbProvider).parseLoadBalancingPolicyConfig(ArgumentMatchers.>any()); @@ -185,9 +178,26 @@ public void updateBalancingState_triggersListener() { childPolicyReportingHelper.updateBalancingState(ConnectivityState.READY, childPicker); - verify(childLbStatusListener).onStatusChanged(ConnectivityState.READY); assertThat(childPolicyWrapper.getPicker()).isEqualTo(childPicker); // picker governs childPickers will be reported to parent LB verify(helper).updateBalancingState(ConnectivityState.READY, picker); } + + @Test + public void refCountedGetOrCreate_addsChildBeforeConfiguringChild() { + AtomicBoolean calledAlready = new AtomicBoolean(); + when(lb.acceptResolvedAddresses(any(ResolvedAddresses.class))).thenAnswer(i -> { + if (!calledAlready.get()) { + calledAlready.set(true); + // Should end up calling this function again, as this child should already be added to the + // list of children. 
In practice, this can be caused by CDS is_dynamic=true starting a watch + // when XdsClient already has the cluster cached (e.g., from another channel). + syncContext.execute(() -> + factory.acceptResolvedAddressFactory(resolvedAddressFactory)); + } + return Status.OK; + }); + ChildPolicyWrapper unused = factory.createOrGet("foo.google.com"); + verify(lb, times(2)).acceptResolvedAddresses(any(ResolvedAddresses.class)); + } } diff --git a/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java b/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java index f38b28d8416..23ffe6ec026 100644 --- a/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java +++ b/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java @@ -25,8 +25,10 @@ import io.grpc.internal.FakeClock; import io.grpc.rls.LruCache.EvictionListener; import io.grpc.rls.LruCache.EvictionType; +import java.util.Arrays; import java.util.Objects; import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -266,4 +268,91 @@ public int hashCode() { return Objects.hash(value, expireTime); } } + + @Test + public void testFitToLimitWithReSize() { + + Entry entry1 = new Entry("Entry1", ticker.read() + 10, 4); + Entry entry2 = new Entry("Entry2", ticker.read() + 20, 1); + Entry entry3 = new Entry("Entry3", ticker.read() + 30, 2); + + cache.cache(1, entry1); + cache.cache(2, entry2); + cache.cache(3, entry3); + + assertThat(cache.estimatedSize()).isEqualTo(2); + assertThat(cache.estimatedSizeBytes()).isEqualTo(3); + assertThat(cache.estimatedMaxSizeBytes()).isEqualTo(5); + + cache.resize(2); + assertThat(cache.estimatedSize()).isEqualTo(1); + assertThat(cache.estimatedSizeBytes()).isEqualTo(2); + assertThat(cache.estimatedMaxSizeBytes()).isEqualTo(2); + + assertThat(cache.fitToLimit()).isEqualTo(false); + } + + @Test + public void testFitToLimit() { + + TestFitToLimitEviction localCache = new TestFitToLimitEviction( + MAX_SIZE, 
+ evictionListener, + fakeClock.getTicker() + ); + + Entry entry1 = new Entry("Entry1", ticker.read() + 10, 4); + Entry entry2 = new Entry("Entry2", ticker.read() + 20, 2); + Entry entry3 = new Entry("Entry3", ticker.read() + 30, 1); + + localCache.cache(1, entry1); + localCache.cache(2, entry2); + localCache.cache(3, entry3); + + assertThat(localCache.estimatedSize()).isEqualTo(3); + assertThat(localCache.estimatedSizeBytes()).isEqualTo(7); + assertThat(localCache.estimatedMaxSizeBytes()).isEqualTo(5); + + localCache.enableEviction(); + + assertThat(localCache.fitToLimit()).isEqualTo(true); + + assertThat(localCache.values().contains(entry1)).isFalse(); + assertThat(localCache.values().containsAll(Arrays.asList(entry2, entry3))).isTrue(); + + assertThat(localCache.estimatedSize()).isEqualTo(2); + assertThat(localCache.estimatedSizeBytes()).isEqualTo(3); + assertThat(localCache.estimatedMaxSizeBytes()).isEqualTo(5); + } + + private static class TestFitToLimitEviction extends LinkedHashLruCache { + + private boolean allowEviction = false; + + TestFitToLimitEviction( + long estimatedMaxSizeBytes, + @Nullable EvictionListener evictionListener, + Ticker ticker) { + super(estimatedMaxSizeBytes, evictionListener, ticker); + } + + @Override + protected boolean isExpired(Integer key, Entry value, long nowNanos) { + return value.expireTime - nowNanos <= 0; + } + + @Override + protected int estimateSizeOf(Integer key, Entry value) { + return value.size; + } + + @Override + protected boolean shouldInvalidateEldestEntry(Integer eldestKey, Entry eldestValue, long now) { + return allowEviction && super.shouldInvalidateEldestEntry(eldestKey, eldestValue, now); + } + + public void enableEviction() { + allowEviction = true; + } + } } diff --git a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java index f3986cb89d5..a52390743a6 100644 --- a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java +++ 
b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java @@ -42,6 +42,7 @@ import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.ForwardingChannelBuilder2; +import io.grpc.Grpc; import io.grpc.InternalManagedChannelBuilder; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; @@ -72,7 +73,6 @@ import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.FakeClock; import io.grpc.internal.JsonParser; -import io.grpc.internal.PickFirstLoadBalancerProvider; import io.grpc.internal.PickSubchannelArgsImpl; import io.grpc.internal.testing.StreamRecorder; import io.grpc.lookup.v1.RouteLookupServiceGrpc; @@ -166,15 +166,19 @@ public void setUp() { .build(); fakeRlsServerImpl.setLookupTable( ImmutableMap.of( - RouteLookupRequest.create(ImmutableMap.of( + RouteLookupRequest.create( + ImmutableMap.of( "server", "fake-bigtable.googleapis.com", "service-key", "com.google", - "method-key", "Search")), + "method-key", "Search"), + RouteLookupRequest.Reason.REASON_MISS), RouteLookupResponse.create(ImmutableList.of("wilderness"), "where are you?"), - RouteLookupRequest.create(ImmutableMap.of( + RouteLookupRequest.create( + ImmutableMap.of( "server", "fake-bigtable.googleapis.com", "service-key", "com.google", - "method-key", "Rescue")), + "method-key", "Rescue"), + RouteLookupRequest.Reason.REASON_MISS), RouteLookupResponse.create(ImmutableList.of("civilization"), "you are safe"))); rlsLb = (RlsLoadBalancer) provider.newLoadBalancer(helper); @@ -201,13 +205,21 @@ public void tearDown() { @Test public void lb_serverStatusCodeConversion() throws Exception { - deliverResolvedAddresses(); + helper.getSynchronizationContext().execute(() -> { + try { + deliverResolvedAddresses(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + assertThat(subchannels.poll()).isNotNull(); // default target + assertThat(subchannels.poll()).isNull(); + // Warm-up pick; will be queued InOrder inOrder = 
inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); SubchannelPicker picker = pickerCaptor.getValue(); PickSubchannelArgs fakeSearchMethodArgs = newPickSubchannelArgs(fakeSearchMethod); - // Warm-up pick; will be queued PickResult res = picker.pickSubchannel(fakeSearchMethodArgs); assertThat(res.getStatus().isOk()).isTrue(); assertThat(res.getSubchannel()).isNull(); @@ -220,8 +232,7 @@ public void lb_serverStatusCodeConversion() throws Exception { subchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); res = picker.pickSubchannel(fakeSearchMethodArgs); assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.OK); - int expectedTimes = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 1 : 2; - verifyLongCounterAdd("grpc.lb.rls.target_picks", expectedTimes, 1, "wilderness", "complete"); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "wilderness", "complete"); // Check on conversion Throwable cause = new Throwable("cause"); @@ -236,7 +247,13 @@ public void lb_serverStatusCodeConversion() throws Exception { @Test public void lb_working_withDefaultTarget_rlsResponding() throws Exception { - deliverResolvedAddresses(); + helper.getSynchronizationContext().execute(() -> { + try { + deliverResolvedAddresses(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); InOrder inOrder = inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); @@ -257,7 +274,7 @@ public void lb_working_withDefaultTarget_rlsResponding() throws Exception { inOrder.verifyNoMoreInteractions(); assertThat(res.getStatus().isOk()).isTrue(); - assertThat(subchannels).hasSize(1); + assertThat(subchannels).hasSize(2); // includes fallback sub-channel FakeSubchannel searchSubchannel = subchannels.getLast(); assertThat(subchannelIsReady(searchSubchannel)).isFalse(); @@ -268,8 +285,7 @@ public void 
lb_working_withDefaultTarget_rlsResponding() throws Exception { res = picker.pickSubchannel(searchSubchannelArgs); assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); assertThat(res.getSubchannel()).isSameInstanceAs(searchSubchannel); - int expectedTimes = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 1 : 2; - verifyLongCounterAdd("grpc.lb.rls.target_picks", expectedTimes, 1, "wilderness", "complete"); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "wilderness", "complete"); // rescue should be pending status although the overall channel state is READY res = picker.pickSubchannel(rescueSubchannelArgs); @@ -277,7 +293,7 @@ public void lb_working_withDefaultTarget_rlsResponding() throws Exception { // other rls picker itself is ready due to first channel. assertThat(res.getStatus().isOk()).isTrue(); assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); - assertThat(subchannels).hasSize(2); + assertThat(subchannels).hasSize(3); // includes fallback sub-channel FakeSubchannel rescueSubchannel = subchannels.getLast(); // search subchannel is down, rescue subchannel is connecting @@ -355,8 +371,10 @@ public void metricsWithRealChannel() throws Exception { .build()); StreamRecorder recorder = StreamRecorder.create(); + CallOptions callOptions = CallOptions.DEFAULT + .withOption(Grpc.CALL_OPTION_CUSTOM_LABEL, "customvalue"); StreamObserver requestObserver = ClientCalls.asyncClientStreamingCall( - channel.newCall(fakeSearchMethod, CallOptions.DEFAULT), recorder); + channel.newCall(fakeSearchMethod, callOptions), recorder); requestObserver.onCompleted(); assertThat(recorder.awaitCompletion(10, TimeUnit.SECONDS)).isTrue(); assertThat(recorder.getError()).isNull(); @@ -366,7 +384,7 @@ public void metricsWithRealChannel() throws Exception { eq(1L), eq(Arrays.asList("directaddress:///fake-bigtable.googleapis.com", "localhost:8972", "defaultTarget", "complete")), - eq(Arrays.asList())); + eq(Arrays.asList("customvalue"))); } @Test @@ -393,7 
+411,13 @@ public void lb_working_withoutDefaultTarget_noRlsResponse() throws Exception { public void lb_working_withDefaultTarget_noRlsResponse() throws Exception { fakeThrottler.nextResult = true; - deliverResolvedAddresses(); + helper.getSynchronizationContext().execute(() -> { + try { + deliverResolvedAddresses(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); InOrder inOrder = inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); @@ -409,7 +433,7 @@ public void lb_working_withDefaultTarget_noRlsResponse() throws Exception { inOrder.verify(helper).getMetricRecorder(); inOrder.verify(helper).getChannelTarget(); inOrder.verifyNoMoreInteractions(); - int times = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 1 : 2; + int times = 1; verifyLongCounterAdd("grpc.lb.rls.default_target_picks", times, 1, "defaultTarget", "complete"); @@ -434,8 +458,7 @@ public void lb_working_withDefaultTarget_noRlsResponse() throws Exception { (FakeSubchannel) markReadyAndGetPickResult(inOrder, searchSubchannelArgs).getSubchannel(); assertThat(searchSubchannel).isNotNull(); assertThat(searchSubchannel).isNotSameInstanceAs(fallbackSubchannel); - times = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 1 : 2; - verifyLongCounterAdd("grpc.lb.rls.target_picks", times, 1, "wilderness", "complete"); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "wilderness", "complete"); // create rescue subchannel picker.pickSubchannel(rescueSubchannelArgs); @@ -515,8 +538,7 @@ public void lb_working_withoutDefaultTarget() throws Exception { res = picker.pickSubchannel(newPickSubchannelArgs(fakeSearchMethod)); assertThat(res.getStatus().isOk()).isFalse(); assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); - int expectedTimes = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 
1 : 2; - verifyLongCounterAdd("grpc.lb.rls.target_picks", expectedTimes, 1, "wilderness", "complete"); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "wilderness", "complete"); res = picker.pickSubchannel(newPickSubchannelArgs(fakeRescueMethod)); assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); @@ -535,7 +557,13 @@ public void lb_working_withoutDefaultTarget() throws Exception { @Test public void lb_nameResolutionFailed() throws Exception { - deliverResolvedAddresses(); + helper.getSynchronizationContext().execute(() -> { + try { + deliverResolvedAddresses(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); InOrder inOrder = inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); @@ -545,7 +573,7 @@ public void lb_nameResolutionFailed() throws Exception { assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); - assertThat(subchannels).hasSize(1); + assertThat(subchannels).hasSize(2); // includes fallback sub-channel FakeSubchannel searchSubchannel = subchannels.getLast(); searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); @@ -660,7 +688,7 @@ private void verifyLongCounterAdd(String name, int times, long value, verify(mockMetricRecorder, times(times)).addLongCounter( eqMetricInstrumentName(name), eq(value), eq(Lists.newArrayList(channelTarget, "localhost:8972", dataPlaneTargetLabel, pickResult)), - eq(Lists.newArrayList())); + eq(Lists.newArrayList(""))); } // This one is for verifying the failed_pick metric specifically. 
@@ -669,7 +697,7 @@ private void verifyFailedPicksCounterAdd(int times, long value) { verify(mockMetricRecorder, times(times)).addLongCounter( eqMetricInstrumentName("grpc.lb.rls.failed_picks"), eq(value), eq(Lists.newArrayList(channelTarget, "localhost:8972")), - eq(Lists.newArrayList())); + eq(Lists.newArrayList(""))); } @SuppressWarnings("TypeParameterUnusedInFormals") diff --git a/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java b/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java index 98b7101fd5e..82ad606c50d 100644 --- a/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java @@ -61,12 +61,14 @@ public void convert_toRequestObject() { Converter converter = new RouteLookupRequestConverter().reverse(); RlsProtoData.RouteLookupRequest requestObject = - RlsProtoData.RouteLookupRequest.create(ImmutableMap.of("key1", "val1")); + RlsProtoData.RouteLookupRequest.create(ImmutableMap.of("key1", "val1"), + RlsProtoData.RouteLookupRequest.Reason.REASON_MISS); RouteLookupRequest proto = converter.convert(requestObject); assertThat(proto.getTargetType()).isEqualTo("grpc"); assertThat(proto.getKeyMapMap()).containsExactly("key1", "val1"); + assertThat(proto.getReason()).isEqualTo(RouteLookupRequest.Reason.REASON_MISS); } @Test @@ -469,6 +471,124 @@ public void convert_jsonRlsConfig_staleAgeGivenWithoutMaxAge() throws IOExceptio } } + @Test + public void convert_jsonRlsConfig_doNotClampMaxAgeIfStaleAgeIsSet() throws IOException { + String jsonStr = "{\n" + + " \"grpcKeybuilders\": [\n" + + " {\n" + + " \"names\": [\n" + + " {\n" + + " \"service\": \"service1\",\n" + + " \"method\": \"create\"\n" + + " }\n" + + " ],\n" + + " \"headers\": [\n" + + " {\n" + + " \"key\": \"user\"," + + " \"names\": [\"User\", \"Parent\"],\n" + + " \"optional\": true\n" + + " },\n" + + " {\n" + + " \"key\": \"id\"," + + " \"names\": [\"X-Google-Id\"],\n" + + " \"optional\": true\n" + + " }\n" + + " ]\n" + + " }\n" 
+ + " ],\n" + + " \"lookupService\": \"service1\",\n" + + " \"lookupServiceTimeout\": \"2s\",\n" + + " \"maxAge\": \"350s\",\n" + + " \"staleAge\": \"310s\",\n" + + " \"validTargets\": [\"a valid target\"]," + + " \"cacheSizeBytes\": \"1000\",\n" + + " \"defaultTarget\": \"us_east_1.cloudbigtable.googleapis.com\"\n" + + "}"; + + RouteLookupConfig expectedConfig = + RouteLookupConfig.builder() + .grpcKeybuilders(ImmutableList.of( + GrpcKeyBuilder.create( + ImmutableList.of(Name.create("service1", "create")), + ImmutableList.of( + NameMatcher.create("user", ImmutableList.of("User", "Parent")), + NameMatcher.create("id", ImmutableList.of("X-Google-Id"))), + ExtraKeys.DEFAULT, + ImmutableMap.of()))) + .lookupService("service1") + .lookupServiceTimeoutInNanos(TimeUnit.SECONDS.toNanos(2)) + .maxAgeInNanos(TimeUnit.SECONDS.toNanos(350)) // Should not be clamped + .staleAgeInNanos(TimeUnit.SECONDS.toNanos(300)) // Should be clamped to max 300s + .cacheSizeBytes(1000) + .defaultTarget("us_east_1.cloudbigtable.googleapis.com") + .build(); + + RouteLookupConfigConverter converter = new RouteLookupConfigConverter(); + @SuppressWarnings("unchecked") + Map parsedJson = (Map) JsonParser.parse(jsonStr); + RouteLookupConfig converted = converter.convert(parsedJson); + assertThat(converted).isEqualTo(expectedConfig); + } + + @Test + public void convert_jsonRlsConfig_clampMaxAgeIfStaleAgeMissing() throws IOException { + String jsonStr = "{\n" + + " \"grpcKeybuilders\": [\n" + + " {\n" + + " \"names\": [\n" + + " {\n" + + " \"service\": \"service1\",\n" + + " \"method\": \"create\"\n" + + " }\n" + + " ],\n" + + " \"headers\": [\n" + + " {\n" + + " \"key\": \"user\"," + + " \"names\": [\"User\", \"Parent\"],\n" + + " \"optional\": true\n" + + " },\n" + + " {\n" + + " \"key\": \"id\"," + + " \"names\": [\"X-Google-Id\"],\n" + + " \"optional\": true\n" + + " }\n" + + " ]\n" + + " }\n" + + " ],\n" + + " \"lookupService\": \"service1\",\n" + + " \"lookupServiceTimeout\": \"2s\",\n" + + " 
\"maxAge\": \"350s\",\n" // Exceeds 5m limit + + " \"validTargets\": [\"a valid target\"]," + + " \"cacheSizeBytes\": \"1000\",\n" + + " \"defaultTarget\": \"us_east_1.cloudbigtable.googleapis.com\"\n" + + "}"; + + RouteLookupConfig expectedConfig = + RouteLookupConfig.builder() + .grpcKeybuilders(ImmutableList.of( + GrpcKeyBuilder.create( + ImmutableList.of(Name.create("service1", "create")), + ImmutableList.of( + NameMatcher.create("user", ImmutableList.of("User", "Parent")), + NameMatcher.create("id", ImmutableList.of("X-Google-Id"))), + ExtraKeys.DEFAULT, + ImmutableMap.of()))) + .lookupService("service1") + .lookupServiceTimeoutInNanos(TimeUnit.SECONDS.toNanos(2)) + // Should be clamped to 300s (5m) because staleAge is missing + .maxAgeInNanos(TimeUnit.MINUTES.toNanos(5)) + .staleAgeInNanos(TimeUnit.MINUTES.toNanos(5)) + .cacheSizeBytes(1000) + .defaultTarget("us_east_1.cloudbigtable.googleapis.com") + .build(); + + RouteLookupConfigConverter converter = new RouteLookupConfigConverter(); + @SuppressWarnings("unchecked") + Map parsedJson = (Map) JsonParser.parse(jsonStr); + RouteLookupConfig converted = converter.convert(parsedJson); + assertThat(converted).isEqualTo(expectedConfig); + } + @Test public void convert_jsonRlsConfig_keyBuilderWithoutName() throws IOException { String jsonStr = "{\n" diff --git a/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java b/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java index 6ee2c01af8a..2b900994ed9 100644 --- a/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java @@ -26,7 +26,6 @@ import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; import io.grpc.rls.RlsProtoData.RouteLookupConfig; -import io.grpc.rls.RlsProtoData.RouteLookupRequest; import java.util.concurrent.TimeUnit; import org.junit.Test; import org.junit.runner.RunWith; @@ -82,8 +81,9 @@ public void create_pathMatches() { 
metadata.put(Metadata.Key.of("X-Google-Id", Metadata.ASCII_STRING_MARSHALLER), "123"); metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - RouteLookupRequest request = factory.create("com.google.service1", "Create", metadata); - assertThat(request.keyMap()).containsExactly( + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + factory.create("com.google.service1", "Create", metadata); + assertThat(routeLookupRequestKey.keyMap()).containsExactly( "user", "test", "id", "123", "server-1", "bigtable.googleapis.com", @@ -97,9 +97,10 @@ public void create_pathFallbackMatches() { metadata.put(Metadata.Key.of("Password", Metadata.ASCII_STRING_MARSHALLER), "hunter2"); metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - RouteLookupRequest request = factory.create("com.google.service1" , "Update", metadata); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + factory.create("com.google.service1" , "Update", metadata); - assertThat(request.keyMap()).containsExactly( + assertThat(routeLookupRequestKey.keyMap()).containsExactly( "user", "test", "password", "hunter2", "service-2", "com.google.service1", @@ -113,9 +114,10 @@ public void create_pathFallbackMatches_optionalHeaderMissing() { metadata.put(Metadata.Key.of("X-Google-Id", Metadata.ASCII_STRING_MARSHALLER), "123"); metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - RouteLookupRequest request = factory.create("com.google.service1", "Update", metadata); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + factory.create("com.google.service1", "Update", metadata); - assertThat(request.keyMap()).containsExactly( + assertThat(routeLookupRequestKey.keyMap()).containsExactly( "user", "test", "service-2", "com.google.service1", "const-key-2", "const-value-2"); @@ -128,8 +130,9 @@ public void create_unknownPath() { metadata.put(Metadata.Key.of("X-Google-Id", Metadata.ASCII_STRING_MARSHALLER), "123"); 
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - RouteLookupRequest request = factory.create("abc.def.service999", "Update", metadata); - assertThat(request.keyMap()).isEmpty(); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + factory.create("abc.def.service999", "Update", metadata); + assertThat(routeLookupRequestKey.keyMap()).isEmpty(); } @Test @@ -139,9 +142,10 @@ public void create_noMethodInRlsConfig() { metadata.put(Metadata.Key.of("X-Google-Id", Metadata.ASCII_STRING_MARSHALLER), "123"); metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - RouteLookupRequest request = factory.create("com.google.service3", "Update", metadata); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + factory.create("com.google.service3", "Update", metadata); - assertThat(request.keyMap()).containsExactly( + assertThat(routeLookupRequestKey.keyMap()).containsExactly( "user", "test", "const-key-4", "const-value-4"); } } diff --git a/s2a/BUILD.bazel b/s2a/BUILD.bazel new file mode 100644 index 00000000000..34387206ba5 --- /dev/null +++ b/s2a/BUILD.bazel @@ -0,0 +1,93 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + +java_library( + name = "s2a_channel_pool", + srcs = glob([ + "src/main/java/io/grpc/s2a/internal/channel/*.java", + ]), + deps = [ + "//api", + "//core", + "//core:internal", + "//netty", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("org.checkerframework:checker-qual"), + artifact("io.netty:netty-common"), + artifact("io.netty:netty-transport"), + ], +) + +java_library( + name = "s2a_identity", + srcs = ["src/main/java/io/grpc/s2a/internal/handshaker/S2AIdentity.java"], + deps = [ + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + 
artifact("com.google.s2a.proto.v2:s2a-proto"), + ], +) + +java_library( + name = "token_manager", + srcs = glob([ + "src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/*.java", + ]), + deps = [ + ":s2a_identity", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), + ], +) + +java_library( + name = "s2a_handshaker", + srcs = [ + "src/main/java/io/grpc/s2a/internal/handshaker/ConnectionClosedException.java", + "src/main/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanisms.java", + "src/main/java/io/grpc/s2a/internal/handshaker/ProtoUtil.java", + "src/main/java/io/grpc/s2a/internal/handshaker/S2AConnectionException.java", + "src/main/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethod.java", + "src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java", + "src/main/java/io/grpc/s2a/internal/handshaker/S2AStub.java", + "src/main/java/io/grpc/s2a/internal/handshaker/S2ATrustManager.java", + "src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java", + ], + deps = [ + ":s2a_identity", + ":token_manager", + "//api", + "//core:internal", + "//netty", + "//stub", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("com.google.s2a.proto.v2:s2a-proto"), + artifact("org.checkerframework:checker-qual"), + "@com_google_protobuf//:protobuf_java", + artifact("io.netty:netty-common"), + artifact("io.netty:netty-handler"), + artifact("io.netty:netty-transport"), + ], +) + +java_library( + name = "s2a", + srcs = ["src/main/java/io/grpc/s2a/S2AChannelCredentials.java"], + visibility = ["//visibility:public"], + deps = [ + ":s2a_channel_pool", + ":s2a_handshaker", + ":s2a_identity", + "//api", + "//core:internal", + "//netty", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + 
artifact("org.checkerframework:checker-qual"), + ], +) diff --git a/s2a/build.gradle b/s2a/build.gradle new file mode 100644 index 00000000000..c46993ec9c8 --- /dev/null +++ b/s2a/build.gradle @@ -0,0 +1,120 @@ +plugins { + id "java-library" + id "maven-publish" + + id "com.google.osdetector" + id "com.google.protobuf" + id "com.gradleup.shadow" + id "ru.vyarus.animalsniffer" +} + +description = "gRPC: S2A" + +dependencies { + implementation libraries.s2a.proto + implementation 'org.checkerframework:checker-qual:3.49.5' + + api project(':grpc-api') + implementation project(':grpc-stub'), + project(':grpc-protobuf'), + project(':grpc-core'), + libraries.protobuf.java, + libraries.guava.jre // JRE required by protobuf-java-util from grpclb + def nettyDependency = implementation project(':grpc-netty') + + shadow configurations.implementation.getDependencies().minus(nettyDependency) + shadow project(path: ':grpc-netty-shaded', configuration: 'shadow') + + testImplementation project(':grpc-benchmarks'), + project(':grpc-testing'), + project(':grpc-testing-proto'), + testFixtures(project(':grpc-core')), + libraries.guava + + testImplementation 'com.google.truth:truth:1.4.2' + testImplementation 'com.google.truth.extensions:truth-proto-extension:1.4.2' + testImplementation libraries.guava.testlib + + testRuntimeOnly libraries.netty.tcnative, + libraries.netty.tcnative.classes + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "linux-x86_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "linux-aarch_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "osx-x86_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "osx-aarch_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "windows-x86_64" + } + } + + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } +} + 
+configureProtoCompilation() + +tasks.named("javadoc").configure { + exclude 'io/grpc/s2a/**' +} + +tasks.named("jar").configure { + // Must use a different archiveClassifier to avoid conflicting with shadowJar + archiveClassifier = 'original' + manifest { + attributes('Automatic-Module-Name': 'io.grpc.s2a') + } +} + +// We want to use grpc-netty-shaded instead of grpc-netty. But we also want our +// source to work with Bazel, so we rewrite the code as part of the build. +tasks.named("shadowJar").configure { + archiveClassifier = null + dependencies { + exclude(dependency {true}) + } + relocate 'io.grpc.netty', 'io.grpc.netty.shaded.io.grpc.netty' + relocate 'io.netty', 'io.grpc.netty.shaded.io.netty' +} + +plugins.withId('maven-publish') { +publishing { + publications { + maven(MavenPublication) { + // We want this to throw an exception if it isn't working + def originalJar = artifacts.find { dep -> dep.classifier == 'original'} + artifacts.remove(originalJar) + + pom.withXml { + def dependenciesNode = new Node(null, 'dependencies') + project.configurations.shadow.allDependencies.each { dep -> + def dependencyNode = dependenciesNode.appendNode('dependency') + dependencyNode.appendNode('groupId', dep.group) + dependencyNode.appendNode('artifactId', dep.name) + dependencyNode.appendNode('version', dep.version) + dependencyNode.appendNode('scope', 'compile') + } + asNode().dependencies[0].replaceNode(dependenciesNode) + } + } + } +} +} diff --git a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java new file mode 100644 index 00000000000..4be32475205 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java @@ -0,0 +1,135 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Strings.isNullOrEmpty; + +import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.ExperimentalApi; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.netty.InternalNettyChannelCredentials; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory; +import io.grpc.s2a.internal.handshaker.S2AStub; +import javax.annotation.concurrent.NotThreadSafe; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Configures gRPC to use S2A for transport security when establishing a secure channel. Only for + * use on the client side of a gRPC connection. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11533") +public final class S2AChannelCredentials { + /** + * Creates a channel credentials builder for establishing an S2A-secured connection. + * + * @param s2aAddress the address of the S2A server used to secure the connection. + * @param s2aChannelCredentials the credentials to be used when connecting to the S2A. 
+ * @return a {@code S2AChannelCredentials.Builder} instance. + */ + public static Builder newBuilder(String s2aAddress, ChannelCredentials s2aChannelCredentials) { + checkArgument(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); + checkNotNull(s2aChannelCredentials, "S2A channel credentials must not be null"); + return new Builder(s2aAddress, s2aChannelCredentials); + } + + /** Builds an {@code S2AChannelCredentials} instance. */ + @NotThreadSafe + public static final class Builder { + private final String s2aAddress; + private final ChannelCredentials s2aChannelCredentials; + private @Nullable S2AIdentity localIdentity = null; + private @Nullable S2AStub stub = null; + + Builder(String s2aAddress, ChannelCredentials s2aChannelCredentials) { + this.s2aAddress = s2aAddress; + this.s2aChannelCredentials = s2aChannelCredentials; + } + + /** + * Sets the local identity of the client in the form of a SPIFFE ID. The client may set at most + * 1 local identity. If no local identity is specified, then the S2A chooses a default local + * identity, if one exists. + */ + @CanIgnoreReturnValue + public Builder setLocalSpiffeId(String localSpiffeId) { + checkNotNull(localSpiffeId); + checkArgument(localIdentity == null, "localIdentity is already set."); + localIdentity = S2AIdentity.fromSpiffeId(localSpiffeId); + return this; + } + + /** + * Sets the local identity of the client in the form of a hostname. The client may set at most 1 + * local identity. If no local identity is specified, then the S2A chooses a default local + * identity, if one exists. + */ + @CanIgnoreReturnValue + public Builder setLocalHostname(String localHostname) { + checkNotNull(localHostname); + checkArgument(localIdentity == null, "localIdentity is already set."); + localIdentity = S2AIdentity.fromHostname(localHostname); + return this; + } + + /** + * Sets the local identity of the client in the form of a UID. The client may set at most 1 + * local identity. 
If no local identity is specified, then the S2A chooses a default local + * identity, if one exists. + */ + @CanIgnoreReturnValue + public Builder setLocalUid(String localUid) { + checkNotNull(localUid); + checkArgument(localIdentity == null, "localIdentity is already set."); + localIdentity = S2AIdentity.fromUid(localUid); + return this; + } + + /** + * Sets the stub to use to communicate with S2A. This is only used for testing that the + * stream to S2A gets closed. + */ + @VisibleForTesting + Builder setStub(S2AStub stub) { + checkNotNull(stub); + this.stub = stub; + return this; + } + + public ChannelCredentials build() { + return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory()); + } + + InternalProtocolNegotiator.ClientFactory buildProtocolNegotiatorFactory() { + ObjectPool s2aChannelPool = + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials)); + checkNotNull(s2aChannelPool, "s2aChannelPool"); + return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool, stub); + } + } + + private S2AChannelCredentials() {} +} diff --git a/s2a/src/main/java/io/grpc/s2a/internal/channel/S2AHandshakerServiceChannel.java b/s2a/src/main/java/io/grpc/s2a/internal/channel/S2AHandshakerServiceChannel.java new file mode 100644 index 00000000000..8453268efc0 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/channel/S2AHandshakerServiceChannel.java @@ -0,0 +1,107 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.channel; + +import static com.google.common.base.Preconditions.checkNotNull; +import static java.util.concurrent.TimeUnit.SECONDS; + +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.netty.NettyChannelBuilder; +import javax.annotation.concurrent.ThreadSafe; + +/** + * Provides APIs for managing gRPC channels to an S2A server. Each channel is local and plaintext. + * If credentials are provided, they are used to secure the channel. + * + *

This is done as follows: for an S2A server, provides an implementation of gRPC's {@link + * SharedResourceHolder.Resource} interface called a {@code Resource}. A {@code + * Resource} is a factory for creating gRPC channels to the S2A server at a given address, + * and a channel must be returned to the {@code Resource} when it is no longer needed. + * + *

Typical usage pattern is below: + * + *

{@code
+ * Resource resource = S2AHandshakerServiceChannel.getChannelResource("localhost:1234",
+ * creds);
+ * Channel channel = resource.create();
+ * // Send an RPC over the channel to the S2A server running at localhost:1234.
+ * resource.close(channel);
+ * }
+ */ +@ThreadSafe +public final class S2AHandshakerServiceChannel { + + /** + * Returns a {@link SharedResourceHolder.Resource} instance for managing channels to an S2A server + * running at {@code s2aAddress}. + * + * @param s2aAddress the address of the S2A, typically in the format {@code host:port}. + * @param s2aChannelCredentials the credentials to use when establishing a connection to the S2A. + * @return a {@link ChannelResource} instance that manages a {@link Channel} to the S2A server + * running at {@code s2aAddress}. + */ + public static Resource getChannelResource( + String s2aAddress, ChannelCredentials s2aChannelCredentials) { + checkNotNull(s2aAddress); + return new ChannelResource(s2aAddress, s2aChannelCredentials); + } + + /** + * Defines how to create and destroy a {@link Channel} instance that uses shared resources. A + * channel created by {@code ChannelResource} is a plaintext, local channel to the service running + * at {@code targetAddress}. + */ + private static class ChannelResource implements Resource { + private final String targetAddress; + private final ChannelCredentials channelCredentials; + + public ChannelResource(String targetAddress, ChannelCredentials channelCredentials) { + this.targetAddress = targetAddress; + this.channelCredentials = channelCredentials; + } + + /** + * Creates a {@code ManagedChannel} instance to the service running at {@code + * targetAddress}. + */ + @Override + public Channel create() { + return NettyChannelBuilder.forTarget(targetAddress, channelCredentials) + .directExecutor() + .idleTimeout(5, SECONDS) + .build(); + } + + /** Destroys a {@code ManagedChannel} instance. 
*/ + @Override + public void close(Channel instanceChannel) { + checkNotNull(instanceChannel); + ManagedChannel channel = (ManagedChannel) instanceChannel; + channel.shutdownNow(); + } + + @Override + public String toString() { + return "grpc-s2a-channel"; + } + } + + private S2AHandshakerServiceChannel() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ConnectionClosedException.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ConnectionClosedException.java new file mode 100644 index 00000000000..d6f1aa70f7c --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ConnectionClosedException.java @@ -0,0 +1,27 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import java.io.IOException; + +/** Indicates that a connection has been closed. */ +@SuppressWarnings("serial") // This class is never serialized. 
+final class ConnectionClosedException extends IOException { + public ConnectionClosedException(String errorMessage) { + super(errorMessage); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanisms.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanisms.java new file mode 100644 index 00000000000..cf632418e66 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanisms.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import com.google.errorprone.annotations.Immutable; +import com.google.s2a.proto.v2.AuthenticationMechanism; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.grpc.s2a.internal.handshaker.tokenmanager.AccessTokenManager; +import java.util.Optional; + +/** Retrieves the authentication mechanism for a given local identity. */ +@Immutable +final class GetAuthenticationMechanisms { + static final Optional TOKEN_MANAGER = AccessTokenManager.create(); + + /** + * Retrieves the authentication mechanism for a given local identity. + * + * @param localIdentity the identity for which to fetch a token. + * @param tokenManager the token manager to use for fetching tokens. + * @return an {@link AuthenticationMechanism} for the given local identity. 
+ */ + static Optional getAuthMechanism(Optional localIdentity, + Optional tokenManager) { + if (!tokenManager.isPresent()) { + return Optional.empty(); + } + AccessTokenManager manager = tokenManager.get(); + // If no identity is provided, fetch the default access token and DO NOT attach an identity + // to the request. + if (!localIdentity.isPresent()) { + return Optional.of( + AuthenticationMechanism.newBuilder().setToken(manager.getDefaultToken()).build()); + } else { + // Fetch an access token for the provided identity. + return Optional.of( + AuthenticationMechanism.newBuilder() + .setIdentity(localIdentity.get().getIdentity()) + .setToken(manager.getToken(localIdentity.get())) + .build()); + } + } + + private GetAuthenticationMechanisms() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ProtoUtil.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ProtoUtil.java new file mode 100644 index 00000000000..0526ec154f9 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ProtoUtil.java @@ -0,0 +1,78 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableSet; +import com.google.s2a.proto.v2.TLSVersion; + +/** Converts proto messages to Netty strings. 
*/ +final class ProtoUtil { + + /** + * Converts a {@link TLSVersion} object to its {@link String} representation. + * + * @param tlsVersion the {@link TLSVersion} object to be converted. + * @return a {@link String} representation of the TLS version. + * @throws IllegalArgumentException if the {@code tlsVersion} is not one of + * the supported TLS versions. + */ + @VisibleForTesting + static String convertTlsProtocolVersion(TLSVersion tlsVersion) { + switch (tlsVersion) { + case TLS_VERSION_1_3: + return "TLSv1.3"; + case TLS_VERSION_1_2: + return "TLSv1.2"; + case TLS_VERSION_1_1: + return "TLSv1.1"; + case TLS_VERSION_1_0: + return "TLSv1"; + default: + throw new IllegalArgumentException( + String.format("TLS version %d is not supported.", tlsVersion.getNumber())); + } + } + + /** + * Builds a set of strings representing all {@link TLSVersion}s between {@code minTlsVersion} and + * {@code maxTlsVersion}. + */ + static ImmutableSet buildTlsProtocolVersionSet( + TLSVersion minTlsVersion, TLSVersion maxTlsVersion) { + ImmutableSet.Builder tlsVersions = ImmutableSet.builder(); + for (TLSVersion tlsVersion : TLSVersion.values()) { + int versionNumber; + try { + versionNumber = tlsVersion.getNumber(); + } catch (IllegalArgumentException e) { + continue; + } + if (versionNumber >= minTlsVersion.getNumber() + && versionNumber <= maxTlsVersion.getNumber()) { + try { + tlsVersions.add(convertTlsProtocolVersion(tlsVersion)); + } catch (IllegalArgumentException e) { + continue; + } + } + } + return tlsVersions.build(); + } + + private ProtoUtil() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AConnectionException.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AConnectionException.java new file mode 100644 index 00000000000..9b6c244751b --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AConnectionException.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under 
the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +/** Exception that denotes a runtime error that was encountered when talking to the S2A server. */ +@SuppressWarnings("serial") // This class is never serialized. +public class S2AConnectionException extends RuntimeException { + S2AConnectionException(String message) { + super(message); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AIdentity.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AIdentity.java new file mode 100644 index 00000000000..f4d6b88ce45 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AIdentity.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.errorprone.annotations.ThreadSafe; +import com.google.s2a.proto.v2.Identity; + +/** + * Stores an identity in such a way that it can be sent to the S2A handshaker service. The identity + * may be formatted as a SPIFFE ID or as a hostname. + */ +@ThreadSafe +public final class S2AIdentity { + private final Identity identity; + + /** Returns an {@link S2AIdentity} instance with SPIFFE ID set to {@code spiffeId}. */ + public static S2AIdentity fromSpiffeId(String spiffeId) { + checkNotNull(spiffeId); + return new S2AIdentity(Identity.newBuilder().setSpiffeId(spiffeId).build()); + } + + /** Returns an {@link S2AIdentity} instance with hostname set to {@code hostname}. */ + public static S2AIdentity fromHostname(String hostname) { + checkNotNull(hostname); + return new S2AIdentity(Identity.newBuilder().setHostname(hostname).build()); + } + + /** Returns an {@link S2AIdentity} instance with UID set to {@code uid}. */ + public static S2AIdentity fromUid(String uid) { + checkNotNull(uid); + return new S2AIdentity(Identity.newBuilder().setUid(uid).build()); + } + + /** Returns an {@link S2AIdentity} instance with {@code identity} set. */ + public static S2AIdentity fromIdentity(Identity identity) { + return new S2AIdentity(identity == null ? Identity.getDefaultInstance() : identity); + } + + private S2AIdentity(Identity identity) { + this.identity = identity; + } + + /** Returns the proto {@link Identity} representation of this identity instance. 
*/ + public Identity getIdentity() { + return identity; + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethod.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethod.java new file mode 100644 index 00000000000..1a5c37eb989 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethod.java @@ -0,0 +1,147 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import com.google.s2a.proto.v2.OffloadPrivateKeyOperationReq; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import com.google.s2a.proto.v2.SignatureAlgorithm; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslPrivateKeyMethod; +import java.io.IOException; +import java.util.Optional; +import javax.annotation.concurrent.NotThreadSafe; +import javax.net.ssl.SSLEngine; + +/** + * Handles requests on signing bytes with a private key designated by {@code stub}. + * + *

This is done by sending the to-be-signed bytes to an S2A server (designated by {@code stub}) + * and read the signature from the server. + * + *

OpenSSL libraries must be appropriately initialized before using this class. One possible way + * to initialize OpenSSL library is to call {@code + * GrpcSslContexts.configure(SslContextBuilder.forClient());}. + */ +@NotThreadSafe +final class S2APrivateKeyMethod implements OpenSslPrivateKeyMethod { + private final S2AStub stub; + private final Optional localIdentity; + private static final ImmutableMap + OPENSSL_TO_S2A_SIGNATURE_ALGORITHM_MAP = + ImmutableMap.of( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA256, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA384, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA512, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384, + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512, + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA384, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA512); + + public static S2APrivateKeyMethod create(S2AStub stub, Optional localIdentity) { + checkNotNull(stub); + return new S2APrivateKeyMethod(stub, localIdentity); + } + + private S2APrivateKeyMethod(S2AStub stub, Optional localIdentity) { + this.stub = stub; + this.localIdentity = localIdentity; + } + + /** + * Converts the signature algorithm to an enum understood by S2A. + * + * @param signatureAlgorithm the int representation of the signature algorithm define by {@code + * OpenSslPrivateKeyMethod}. 
+ * @return the signature algorithm enum defined by S2A proto. + * @throws UnsupportedOperationException if the algorithm is not supported by S2A. + */ + @VisibleForTesting + static SignatureAlgorithm convertOpenSslSignAlgToS2ASignAlg(int signatureAlgorithm) { + SignatureAlgorithm sig = OPENSSL_TO_S2A_SIGNATURE_ALGORITHM_MAP.get(signatureAlgorithm); + if (sig == null) { + throw new UnsupportedOperationException( + String.format("Signature Algorithm %d is not supported.", signatureAlgorithm)); + } + return sig; + } + + /** + * Signs the input bytes by sending the request to the S2A server. + * + * @param engine not used. + * @param signatureAlgorithm the {@link OpenSslPrivateKeyMethod}'s signature algorithm + * representation + * @param input the bytes to be signed. + * @return the signature of the {@code input}. + * @throws IOException if the connection to the S2A server is corrupted. + * @throws InterruptedException if the connection to the S2A server is interrupted. + * @throws S2AConnectionException if the response from the S2A server does not contain valid data. 
+ */ + @Override + public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) + throws IOException, InterruptedException { + checkArgument(input.length > 0, "No bytes to sign."); + SignatureAlgorithm s2aSignatureAlgorithm = + convertOpenSslSignAlgToS2ASignAlg(signatureAlgorithm); + SessionReq.Builder reqBuilder = + SessionReq.newBuilder() + .setOffloadPrivateKeyOperationReq( + OffloadPrivateKeyOperationReq.newBuilder() + .setOperation(OffloadPrivateKeyOperationReq.PrivateKeyOperation.SIGN) + .setSignatureAlgorithm(s2aSignatureAlgorithm) + .setRawBytes(ByteString.copyFrom(input))); + if (localIdentity.isPresent()) { + reqBuilder.setLocalIdentity(localIdentity.get().getIdentity()); + } + + SessionResp resp = stub.send(reqBuilder.build()); + + if (resp.hasStatus() && resp.getStatus().getCode() != 0) { + throw new S2AConnectionException( + String.format( + "Error occurred in response from S2A, error code: %d, error message: \"%s\".", + resp.getStatus().getCode(), resp.getStatus().getDetails())); + } + if (!resp.hasOffloadPrivateKeyOperationResp()) { + throw new S2AConnectionException("No valid response received from S2A."); + } + return resp.getOffloadPrivateKeyOperationResp().getOutBytes().toByteArray(); + } + + @Override + public byte[] decrypt(SSLEngine engine, byte[] input) { + throw new UnsupportedOperationException("decrypt is not supported."); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java new file mode 100644 index 00000000000..9dcbdcf0509 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java @@ -0,0 +1,282 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Strings.isNullOrEmpty; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.net.HostAndPort; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.s2a.proto.v2.S2AServiceGrpc; +import io.grpc.Channel; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiators; +import io.grpc.netty.InternalProtocolNegotiators.ProtocolNegotiationHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.ssl.SslContext; +import io.netty.util.AsciiString; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.Executors; 
+import javax.annotation.Nullable; + +/** Factory for performing negotiation of a secure channel using the S2A. */ +@ThreadSafe +public final class S2AProtocolNegotiatorFactory { + @VisibleForTesting static final int DEFAULT_PORT = 443; + private static final AsciiString SCHEME = AsciiString.of("https"); + + /** + * Creates a {@code S2AProtocolNegotiatorFactory} configured for a client to establish secure + * connections using the S2A. + * + * @param localIdentity the identity of the client; if none is provided, the S2A will use the + * client's default identity. + * @param s2aChannelPool a pool of shared channels that can be used to connect to the S2A. + * @param stub the stub to use to communicate with S2A. If none is provided the channelPool + * will be used to create the stub. This is exposed for verifying the stream to S2A gets + * closed in tests. + * @return a factory for creating a client-side protocol negotiator. + */ + public static InternalProtocolNegotiator.ClientFactory createClientFactory( + @Nullable S2AIdentity localIdentity, ObjectPool s2aChannelPool, + @Nullable S2AStub stub) { + checkNotNull(s2aChannelPool, "S2A channel pool should not be null."); + return new S2AClientProtocolNegotiatorFactory(localIdentity, s2aChannelPool, stub); + } + + static final class S2AClientProtocolNegotiatorFactory + implements InternalProtocolNegotiator.ClientFactory { + private final @Nullable S2AIdentity localIdentity; + private final ObjectPool channelPool; + private final @Nullable S2AStub stub; + + S2AClientProtocolNegotiatorFactory( + @Nullable S2AIdentity localIdentity, ObjectPool channelPool, + @Nullable S2AStub stub) { + this.localIdentity = localIdentity; + this.channelPool = channelPool; + this.stub = stub; + } + + @Override + public ProtocolNegotiator newNegotiator() { + return S2AProtocolNegotiator.createForClient(channelPool, localIdentity, stub); + } + + @Override + public int getDefaultPort() { + return DEFAULT_PORT; + } + } + + /** Negotiates the TLS 
handshake using S2A. */ + @VisibleForTesting + static final class S2AProtocolNegotiator implements ProtocolNegotiator { + + private final ObjectPool channelPool; + private @Nullable Channel channel = null; + private final Optional localIdentity; + private final @Nullable S2AStub stub; + private final ListeningExecutorService service = + MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1)); + + static S2AProtocolNegotiator createForClient( + ObjectPool channelPool, @Nullable S2AIdentity localIdentity, + @Nullable S2AStub stub) { + checkNotNull(channelPool, "Channel pool should not be null."); + if (localIdentity == null) { + return new S2AProtocolNegotiator(channelPool, Optional.empty(), stub); + } else { + return new S2AProtocolNegotiator(channelPool, Optional.of(localIdentity), stub); + } + } + + @VisibleForTesting + static @Nullable String getHostNameFromAuthority(@Nullable String authority) { + if (authority == null) { + return null; + } + return HostAndPort.fromString(authority).getHost(); + } + + private S2AProtocolNegotiator(ObjectPool channelPool, + Optional localIdentity, @Nullable S2AStub stub) { + this.channelPool = channelPool; + this.localIdentity = localIdentity; + this.stub = stub; + if (this.stub == null) { + this.channel = channelPool.getObject(); + } + } + + @Override + public AsciiString scheme() { + return SCHEME; + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + checkNotNull(grpcHandler, "grpcHandler should not be null."); + String hostname = getHostNameFromAuthority(grpcHandler.getAuthority()); + checkArgument(!isNullOrEmpty(hostname), "hostname should not be null or empty."); + return new S2AProtocolNegotiationHandler( + grpcHandler, channel, localIdentity, hostname, service, stub); + } + + @Override + public void close() { + service.shutdown(); + if (channel != null) { + channelPool.returnObject(channel); + } + } + } + + @VisibleForTesting + static class BufferReadsHandler extends 
ChannelInboundHandlerAdapter { + private final List reads = new ArrayList<>(); + private boolean readComplete; + + public List getReads() { + return reads; + } + + @Override + public void channelRead(ChannelHandlerContext unused, Object msg) { + reads.add(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext unused) { + readComplete = true; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + for (Object msg : reads) { + super.channelRead(ctx, msg); + } + if (readComplete) { + super.channelReadComplete(ctx); + } + } + } + + private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler { + private final @Nullable Channel channel; + private final Optional localIdentity; + private final String hostname; + private final GrpcHttp2ConnectionHandler grpcHandler; + private final ListeningExecutorService service; + private final @Nullable S2AStub stub; + + private S2AProtocolNegotiationHandler( + GrpcHttp2ConnectionHandler grpcHandler, + Channel channel, + Optional localIdentity, + String hostname, + ListeningExecutorService service, + @Nullable S2AStub stub) { + super( + // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' + // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior + // here and then manually add 'next' when we call fireProtocolNegotiationEvent() + new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + ctx.pipeline().remove(this); + } + }, + grpcHandler.getNegotiationLogger()); + this.grpcHandler = grpcHandler; + this.channel = channel; + this.localIdentity = localIdentity; + this.hostname = hostname; + checkNotNull(service, "service should not be null."); + this.service = service; + this.stub = stub; + } + + @Override + protected void handlerAdded0(ChannelHandlerContext ctx) { + // Buffer all reads until the TLS Handler is added. 
+ BufferReadsHandler bufferReads = new BufferReadsHandler(); + ctx.pipeline().addBefore(ctx.name(), /* name= */ null, bufferReads); + + S2AStub s2aStub; + if (this.stub == null) { + checkNotNull(channel, "Channel to S2A should not be null"); + s2aStub = S2AStub.newInstance(S2AServiceGrpc.newStub(channel)); + } else { + s2aStub = this.stub; + } + + ListenableFuture sslContextFuture = + service.submit(() -> SslContextFactory.createForClient(s2aStub, hostname, localIdentity)); + Futures.addCallback( + sslContextFuture, + new FutureCallback() { + @Override + public void onSuccess(SslContext sslContext) { + ChannelHandler handler = + InternalProtocolNegotiators.tls( + sslContext, + SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR), + com.google.common.base.Optional.of(new Runnable() { + @Override + public void run() { + s2aStub.close(); + } + }), + null, null) + .newHandler(grpcHandler); + + // Delegate the rest of the handshake to the TLS handler. and remove the + // bufferReads handler. + ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler); + fireProtocolNegotiationEvent(ctx); + ctx.pipeline().remove(bufferReads); + } + + @Override + public void onFailure(Throwable t) { + ctx.fireExceptionCaught(t); + } + }, + service); + } + } + + private S2AProtocolNegotiatorFactory() {} +} diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AStub.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AStub.java new file mode 100644 index 00000000000..37236f26f4b --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AStub.java @@ -0,0 +1,245 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Verify.verify; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.google.common.annotations.VisibleForTesting; +import com.google.s2a.proto.v2.S2AServiceGrpc; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.util.Optional; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.concurrent.NotThreadSafe; + +/** Reads and writes messages to and from the S2A. 
*/ +@NotThreadSafe +public class S2AStub implements AutoCloseable { + private static final Logger logger = Logger.getLogger(S2AStub.class.getName()); + private static final long HANDSHAKE_RPC_DEADLINE_SECS = 20; + private final StreamObserver reader = new Reader(); + private final BlockingQueue responses = new ArrayBlockingQueue<>(10); + private S2AServiceGrpc.S2AServiceStub serviceStub; + private StreamObserver writer; + private long deadlineSeconds = HANDSHAKE_RPC_DEADLINE_SECS; + private boolean doneReading = false; + private boolean doneWriting = false; + private boolean isClosed = false; + + @VisibleForTesting + public static S2AStub newInstance(S2AServiceGrpc.S2AServiceStub serviceStub) { + checkNotNull(serviceStub); + return new S2AStub(serviceStub); + } + + @VisibleForTesting + static S2AStub newInstanceWithDeadline( + S2AServiceGrpc.S2AServiceStub serviceStub, long deadlineSeconds) { + checkNotNull(serviceStub); + checkArgument(deadlineSeconds > 0); + return new S2AStub(serviceStub, deadlineSeconds); + } + + @VisibleForTesting + static S2AStub newInstanceForTesting(StreamObserver writer) { + checkNotNull(writer); + return new S2AStub(writer); + } + + private S2AStub(S2AServiceGrpc.S2AServiceStub serviceStub) { + this.serviceStub = serviceStub; + } + + private S2AStub(S2AServiceGrpc.S2AServiceStub serviceStub, long deadlineSeconds) { + this.serviceStub = serviceStub; + this.deadlineSeconds = deadlineSeconds; + } + + private S2AStub(StreamObserver writer) { + this.writer = writer; + } + + @VisibleForTesting + StreamObserver getReader() { + return reader; + } + + @VisibleForTesting + BlockingQueue getResponses() { + return responses; + } + + /** + * Sends a request and returns the response. Caller must wait until this method executes prior to + * calling it again. If this method throws {@code ConnectionClosedException}, then it should not + * be called again, and both {@code reader} and {@code writer} are closed. 
+ * + * @param req the {@code SessionReq} message to be sent to the S2A server. + * @return the {@code SessionResp} message received from the S2A server. + * @throws ConnectionClosedException if {@code reader} or {@code writer} calls their {@code + * onCompleted} method. + * @throws IOException if an unexpected response is received, or if the {@code reader} or {@code + * writer} calls their {@code onError} method. + */ + @SuppressWarnings("CheckReturnValue") + public SessionResp send(SessionReq req) throws IOException, InterruptedException { + if (doneWriting && doneReading) { + logger.log(Level.INFO, "Stream to the S2A is closed."); + throw new ConnectionClosedException("Stream to the S2A is closed."); + } + createWriterIfNull(); + if (!responses.isEmpty()) { + IOException exception = null; + try { + responses.take().getResultOrThrow(); + } catch (IOException e) { + exception = e; + } + responses.clear(); + if (exception != null) { + throw new IOException( + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable.", exception); + } else { + throw new IOException("Received an unexpected response from a host at the S2A's address."); + } + } + try { + writer.onNext(req); + } catch (RuntimeException e) { + writer.onError(e); + responses.add(Result.createWithThrowable(e)); + } + try { + return responses.take().getResultOrThrow(); + } catch (ConnectionClosedException e) { + // A ConnectionClosedException is thrown by getResultOrThrow when reader calls its + // onCompleted method. The close method is called to also close the writer, and then the + // ConnectionClosedException is re-thrown in order to indicate to the caller that send + // should not be called again. 
+ close(); + throw e; + } + } + + @Override + public void close() { + if (doneWriting && doneReading) { + return; + } + verify(!doneWriting); + doneReading = true; + doneWriting = true; + if (writer != null) { + writer.onCompleted(); + } + isClosed = true; + } + + public boolean isClosed() { + return isClosed; + } + + /** Create a new writer if the writer is null. */ + private void createWriterIfNull() { + if (writer == null) { + writer = + serviceStub + .withWaitForReady() + .withDeadlineAfter(deadlineSeconds, SECONDS) + .setUpSession(reader); + } + } + + private class Reader implements StreamObserver { + /** + * Places a {@code SessionResp} message in the {@code responses} queue, or an {@code + * IOException} if reading is complete. + * + * @param resp the {@code SessionResp} message received from the S2A handshaker module. + */ + @Override + public void onNext(SessionResp resp) { + verify(!doneReading); + responses.add(Result.createWithResponse(resp)); + } + + /** + * Places a {@code Throwable} in the {@code responses} queue. + * + * @param t the {@code Throwable} caught when reading the stream to the S2A handshaker module. + */ + @Override + public void onError(Throwable t) { + responses.add(Result.createWithThrowable(t)); + } + + /** + * Sets {@code doneReading} to true, and places a {@code ConnectionClosedException} in the + * {@code responses} queue. 
+ */ + @Override + public void onCompleted() { + logger.log(Level.INFO, "Reading from the S2A is complete."); + doneReading = true; + responses.add( + Result.createWithThrowable( + new ConnectionClosedException("Reading from the S2A is complete."))); + } + } + + private static final class Result { + private final Optional response; + private final Optional throwable; + + static Result createWithResponse(SessionResp response) { + return new Result(Optional.of(response), Optional.empty()); + } + + static Result createWithThrowable(Throwable throwable) { + return new Result(Optional.empty(), Optional.of(throwable)); + } + + private Result(Optional response, Optional throwable) { + checkArgument(response.isPresent() != throwable.isPresent()); + this.response = response; + this.throwable = throwable; + } + + /** Throws {@code throwable} if present, and returns {@code response} otherwise. */ + SessionResp getResultOrThrow() throws IOException { + if (throwable.isPresent()) { + if (throwable.get() instanceof ConnectionClosedException) { + ConnectionClosedException exception = (ConnectionClosedException) throwable.get(); + throw exception; + } else { + throw new IOException(throwable.get()); + } + } + verify(response.isPresent()); + return response.get(); + } + } +} diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2ATrustManager.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2ATrustManager.java new file mode 100644 index 00000000000..a7ffafd01f2 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2ATrustManager.java @@ -0,0 +1,159 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainReq; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainReq.VerificationMode; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainResp; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import java.io.IOException; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Optional; +import javax.annotation.concurrent.NotThreadSafe; +import javax.net.ssl.X509TrustManager; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** Offloads verification of the peer certificate chain to S2A. 
*/ +@NotThreadSafe +final class S2ATrustManager implements X509TrustManager { + private final Optional localIdentity; + private final S2AStub stub; + private final String hostname; + + static S2ATrustManager createForClient( + S2AStub stub, String hostname, Optional localIdentity) { + checkNotNull(stub); + checkNotNull(hostname); + return new S2ATrustManager(stub, hostname, localIdentity); + } + + private S2ATrustManager(S2AStub stub, String hostname, Optional localIdentity) { + this.stub = stub; + this.hostname = hostname; + this.localIdentity = localIdentity; + } + + /** + * Validates the given certificate chain provided by the peer. + * + * @param chain the peer certificate chain + * @param authType the authentication type based on the client certificate + * @throws IllegalArgumentException if null or zero-length chain is passed in for the chain + * parameter. + * @throws CertificateException if the certificate chain is not trusted by this TrustManager. + */ + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + checkPeerTrusted(chain, /* isCheckingClientCertificateChain= */ true); + } + + /** + * Validates the given certificate chain provided by the peer. + * + * @param chain the peer certificate chain + * @param authType the authentication type based on the client certificate + * @throws IllegalArgumentException if null or zero-length chain is passed in for the chain + * parameter. + * @throws CertificateException if the certificate chain is not trusted by this TrustManager. + */ + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + checkPeerTrusted(chain, /* isCheckingClientCertificateChain= */ false); + } + + /** + * Returns null because the accepted issuers are held in S2A and this class receives decision made + * from S2A on the fly about which to use to verify a given chain. + * + * @return null. 
+ */ + @Override + public X509Certificate @Nullable [] getAcceptedIssuers() { + return null; + } + + private void checkPeerTrusted(X509Certificate[] chain, boolean isCheckingClientCertificateChain) + throws CertificateException { + checkNotNull(chain); + checkArgument(chain.length > 0, "Certificate chain has zero certificates."); + + ValidatePeerCertificateChainReq.Builder validatePeerCertificateChainReq = + ValidatePeerCertificateChainReq.newBuilder().setMode(VerificationMode.UNSPECIFIED); + if (isCheckingClientCertificateChain) { + validatePeerCertificateChainReq.setClientPeer( + ValidatePeerCertificateChainReq.ClientPeer.newBuilder() + .addAllCertificateChain(certificateChainToDerChain(chain))); + } else { + validatePeerCertificateChainReq.setServerPeer( + ValidatePeerCertificateChainReq.ServerPeer.newBuilder() + .addAllCertificateChain(certificateChainToDerChain(chain)) + .setServerHostname(hostname)); + } + + SessionReq.Builder reqBuilder = + SessionReq.newBuilder().setValidatePeerCertificateChainReq(validatePeerCertificateChainReq); + if (localIdentity.isPresent()) { + reqBuilder.setLocalIdentity(localIdentity.get().getIdentity()); + } + + SessionResp resp; + try { + resp = stub.send(reqBuilder.build()); + } catch (IOException e) { + throw new CertificateException("Failed to send request to S2A.", e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new CertificateException("Failed to send request to S2A.", e); + } + if (resp.hasStatus() && resp.getStatus().getCode() != 0) { + throw new CertificateException( + String.format( + "Error occurred in response from S2A, error code: %d, error message: %s.", + resp.getStatus().getCode(), resp.getStatus().getDetails())); + } + + if (!resp.hasValidatePeerCertificateChainResp()) { + throw new CertificateException("No valid response received from S2A."); + } + + ValidatePeerCertificateChainResp validationResult = resp.getValidatePeerCertificateChainResp(); + if 
(validationResult.getValidationResult() + != ValidatePeerCertificateChainResp.ValidationResult.SUCCESS) { + throw new CertificateException(validationResult.getValidationDetails()); + } + } + + private static ImmutableList certificateChainToDerChain(X509Certificate[] chain) + throws CertificateEncodingException { + ImmutableList.Builder derChain = ImmutableList.builder(); + for (X509Certificate certificate : chain) { + derChain.add(ByteString.copyFrom(certificate.getEncoded())); + } + return derChain.build(); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java new file mode 100644 index 00000000000..5d4ef9eb667 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java @@ -0,0 +1,187 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.base.Preconditions.checkNotNull; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.collect.ImmutableSet; +import com.google.s2a.proto.v2.AuthenticationMechanism; +import com.google.s2a.proto.v2.ConnectionSide; +import com.google.s2a.proto.v2.GetTlsConfigurationReq; +import com.google.s2a.proto.v2.GetTlsConfigurationResp; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslContextOption; +import io.netty.handler.ssl.OpenSslSessionContext; +import io.netty.handler.ssl.OpenSslX509KeyManagerFactory; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Optional; +import javax.net.ssl.KeyManager; +import javax.net.ssl.SSLSessionContext; + +/** Creates {@link SslContext} objects with TLS configurations from S2A server. */ +final class SslContextFactory { + + /** + * Creates {@link SslContext} objects for client with TLS configurations from S2A server. + * + * @param stub the {@link S2AStub} to talk to the S2A server. + * @param targetName the {@link String} of the server that this client makes connection to. + * @param localIdentity the {@link S2AIdentity} that should be used when talking to S2A server. + * Will use default identity if empty. + * @return a {@link SslContext} object. 
+ * @throws NullPointerException if either {@code stub} or {@code targetName} is null. + * @throws IOException if an unexpected response from S2A server is received. + * @throws InterruptedException if {@code stub} is closed. + */ + static SslContext createForClient( + S2AStub stub, String targetName, Optional localIdentity) + throws IOException, + InterruptedException, + CertificateException, + KeyStoreException, + NoSuchAlgorithmException, + UnrecoverableKeyException, + GeneralSecurityException { + checkNotNull(stub, "stub should not be null."); + checkNotNull(targetName, "targetName should not be null on client side."); + GetTlsConfigurationResp.ClientTlsConfiguration clientTlsConfiguration; + try { + clientTlsConfiguration = getClientTlsConfigurationFromS2A(stub, localIdentity); + } catch (IOException | InterruptedException e) { + throw new GeneralSecurityException("Failed to get client TLS configuration from S2A.", e); + } + + // Use the default value for timeout. + // Use the smallest possible value for cache size. + // The Provider is by default OPENSSL. No need to manually set it. 
+ SslContextBuilder sslContextBuilder = + GrpcSslContexts.configure(SslContextBuilder.forClient()) + .sessionCacheSize(1) + .sessionTimeout(0); + + configureSslContextWithClientTlsConfiguration(clientTlsConfiguration, sslContextBuilder); + sslContextBuilder.trustManager( + S2ATrustManager.createForClient(stub, targetName, localIdentity)); + sslContextBuilder.option( + OpenSslContextOption.PRIVATE_KEY_METHOD, S2APrivateKeyMethod.create(stub, localIdentity)); + + SslContext sslContext = sslContextBuilder.build(); + SSLSessionContext sslSessionContext = sslContext.sessionContext(); + if (sslSessionContext instanceof OpenSslSessionContext) { + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + openSslSessionContext.setSessionCacheEnabled(false); + } + + return sslContext; + } + + private static GetTlsConfigurationResp.ClientTlsConfiguration getClientTlsConfigurationFromS2A( + S2AStub stub, Optional localIdentity) throws IOException, InterruptedException { + checkNotNull(stub, "stub should not be null."); + SessionReq.Builder reqBuilder = SessionReq.newBuilder(); + if (localIdentity.isPresent()) { + reqBuilder.setLocalIdentity(localIdentity.get().getIdentity()); + } + Optional authMechanism = + GetAuthenticationMechanisms.getAuthMechanism(localIdentity, + GetAuthenticationMechanisms.TOKEN_MANAGER); + if (authMechanism.isPresent()) { + reqBuilder.addAuthenticationMechanisms(authMechanism.get()); + } + SessionResp resp = + stub.send( + reqBuilder + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build()); + if (resp.hasStatus() && resp.getStatus().getCode() != 0) { + throw new S2AConnectionException( + String.format( + "response from S2A server has an error %d with error message %s.", + resp.getStatus().getCode(), resp.getStatus().getDetails())); + } + if (!resp.getGetTlsConfigurationResp().hasClientTlsConfiguration()) { + throw new
S2AConnectionException( + "Response from S2A server does NOT contain ClientTlsConfiguration."); + } + return resp.getGetTlsConfigurationResp().getClientTlsConfiguration(); + } + + private static void configureSslContextWithClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration clientTlsConfiguration, + SslContextBuilder sslContextBuilder) + throws CertificateException, + IOException, + KeyStoreException, + NoSuchAlgorithmException, + UnrecoverableKeyException { + sslContextBuilder.keyManager(createKeylessManager(clientTlsConfiguration)); + ImmutableSet tlsVersions; + tlsVersions = + ProtoUtil.buildTlsProtocolVersionSet( + clientTlsConfiguration.getMinTlsVersion(), clientTlsConfiguration.getMaxTlsVersion()); + if (tlsVersions.isEmpty()) { + throw new S2AConnectionException( + "Set of TLS versions received from S2A server is empty or not supported."); + } + sslContextBuilder.protocols(tlsVersions); + } + + private static KeyManager createKeylessManager( + GetTlsConfigurationResp.ClientTlsConfiguration clientTlsConfiguration) + throws CertificateException, + IOException, + KeyStoreException, + NoSuchAlgorithmException, + UnrecoverableKeyException { + X509Certificate[] certificates = + new X509Certificate[clientTlsConfiguration.getCertificateChainCount()]; + for (int i = 0; i < clientTlsConfiguration.getCertificateChainCount(); ++i) { + certificates[i] = convertStringToX509Cert(clientTlsConfiguration.getCertificateChain(i)); + } + KeyManager[] keyManagers = + OpenSslX509KeyManagerFactory.newKeyless(certificates).getKeyManagers(); + if (keyManagers == null || keyManagers.length == 0) { + throw new IllegalStateException("No key managers created."); + } + return keyManagers[0]; + } + + private static X509Certificate convertStringToX509Cert(String certificate) + throws CertificateException { + return (X509Certificate) + CertificateFactory.getInstance("X509") + .generateCertificate(new ByteArrayInputStream(certificate.getBytes(UTF_8))); + } + + private 
SslContextFactory() {} +} diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/AccessTokenManager.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/AccessTokenManager.java new file mode 100644 index 00000000000..65fca46bbb2 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/AccessTokenManager.java @@ -0,0 +1,49 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker.tokenmanager; + +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import java.util.Optional; +import javax.annotation.concurrent.ThreadSafe; + +/** Manages access tokens for authenticating to the S2A. */ +@ThreadSafe +public final class AccessTokenManager { + private final TokenFetcher tokenFetcher; + + /** Creates an {@code AccessTokenManager} based on the environment where the application runs. */ + public static Optional create() { + Optional tokenFetcher = SingleTokenFetcher.create(); + return tokenFetcher.isPresent() + ? Optional.of(new AccessTokenManager(tokenFetcher.get())) + : Optional.empty(); + } + + private AccessTokenManager(TokenFetcher tokenFetcher) { + this.tokenFetcher = tokenFetcher; + } + + /** Returns an access token when no identity is specified. */ + public String getDefaultToken() { + return tokenFetcher.getDefaultToken(); + } + + /** Returns an access token for the given identity. 
*/ + public String getToken(S2AIdentity identity) { + return tokenFetcher.getToken(identity); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/SingleTokenFetcher.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/SingleTokenFetcher.java new file mode 100644 index 00000000000..28aa0f87ba1 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/SingleTokenFetcher.java @@ -0,0 +1,62 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker.tokenmanager; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import java.util.Optional; + +/** Fetches a single access token via an environment variable. */ +@SuppressWarnings("NonFinalStaticField") +public final class SingleTokenFetcher implements TokenFetcher { + private static final String ENVIRONMENT_VARIABLE = "S2A_ACCESS_TOKEN"; + private static String accessToken = System.getenv(ENVIRONMENT_VARIABLE); + + private final String token; + + /** + * Creates a {@code SingleTokenFetcher} from {@code ENVIRONMENT_VARIABLE}, and returns an empty + * {@code Optional} instance if the token could not be fetched. 
+ */ + public static Optional create() { + return Optional.ofNullable(accessToken).map(SingleTokenFetcher::new); + } + + @VisibleForTesting + public static void setAccessToken(String token) { + accessToken = token; + } + + @VisibleForTesting + public static String getAccessToken() { + return accessToken; + } + + private SingleTokenFetcher(String token) { + this.token = token; + } + + @Override + public String getDefaultToken() { + return token; + } + + @Override + public String getToken(S2AIdentity identity) { + return token; + } +} diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/TokenFetcher.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/TokenFetcher.java new file mode 100644 index 00000000000..6827f095afe --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/TokenFetcher.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker.tokenmanager; + +import io.grpc.s2a.internal.handshaker.S2AIdentity; + +/** Fetches tokens used to authenticate to S2A. */ +interface TokenFetcher { + /** Returns an access token when no identity is specified. */ + String getDefaultToken(); + + /** Returns an access token for the given identity. 
*/ + String getToken(S2AIdentity identity); +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/IntegrationTest.java b/s2a/src/test/java/io/grpc/s2a/IntegrationTest.java new file mode 100644 index 00000000000..1d3568808c6 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/IntegrationTest.java @@ -0,0 +1,256 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.truth.Truth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.google.s2a.proto.v2.S2AServiceGrpc; +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; +import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.benchmarks.Utils; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.s2a.S2AChannelCredentials; +import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel; +import io.grpc.s2a.internal.handshaker.FakeS2AServer; +import io.grpc.s2a.internal.handshaker.S2AStub; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.protobuf.SimpleRequest; +import 
io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.netty.handler.ssl.ClientAuth; +import io.netty.handler.ssl.OpenSslSessionContext; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProvider; +import java.io.InputStream; +import java.util.concurrent.FutureTask; +import java.util.logging.Logger; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSessionContext; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class IntegrationTest { + private static final Logger logger = Logger.getLogger(FakeS2AServer.class.getName()); + private String s2aAddress; + private Server s2aServer; + private String s2aDelayAddress; + private Server s2aDelayServer; + private String mtlsS2AAddress; + private Server mtlsS2AServer; + private String serverAddress; + private Server server; + + @Before + public void setUp() throws Exception { + s2aServer = ServerBuilder.forPort(0).addService(new FakeS2AServer()).build().start(); + int s2aPort = s2aServer.getPort(); + s2aAddress = "localhost:" + s2aPort; + logger.info("S2A service listening on localhost:" + s2aPort); + ClassLoader classLoader = IntegrationTest.class.getClassLoader(); + InputStream s2aCert = classLoader.getResourceAsStream("server_cert.pem"); + InputStream s2aKey = classLoader.getResourceAsStream("server_key.pem"); + InputStream rootCert = classLoader.getResourceAsStream("root_cert.pem"); + ServerCredentials s2aCreds = + TlsServerCredentials.newBuilder() + .keyManager(s2aCert, s2aKey) + .trustManager(rootCert) + .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) + .build(); + mtlsS2AServer = NettyServerBuilder.forPort(0, s2aCreds).addService(new FakeS2AServer()).build(); + mtlsS2AServer.start(); + int mtlsS2APort = mtlsS2AServer.getPort(); + mtlsS2AAddress = 
"localhost:" + mtlsS2APort; + logger.info("mTLS S2A service listening on localhost:" + mtlsS2APort); + + int s2aDelayPort = Utils.pickUnusedPort(); + s2aDelayAddress = "localhost:" + s2aDelayPort; + s2aDelayServer = ServerBuilder.forPort(s2aDelayPort).addService(new FakeS2AServer()).build(); + + server = + NettyServerBuilder.forPort(0) + .addService(new SimpleServiceImpl()) + .sslContext(buildSslContext()) + .build() + .start(); + int serverPort = server.getPort(); + serverAddress = "localhost:" + serverPort; + logger.info("Simple Service listening on localhost:" + serverPort); + } + + @After + public void tearDown() throws Exception { + server.shutdown(); + s2aServer.shutdown(); + s2aDelayServer.shutdown(); + mtlsS2AServer.shutdown(); + + server.awaitTermination(10, SECONDS); + s2aServer.awaitTermination(10, SECONDS); + s2aDelayServer.awaitTermination(10, SECONDS); + mtlsS2AServer.awaitTermination(10, SECONDS); + } + + @Test + public void clientCommunicateUsingS2ACredentials_succeeds() throws Exception { + ChannelCredentials credentials = + S2AChannelCredentials.newBuilder(s2aAddress, InsecureChannelCredentials.create()) + .setLocalSpiffeId("test-spiffe-id").build(); + ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); + + assertThat(doUnaryRpc(channel)).isTrue(); + } + + @Test + public void clientCommunicateUsingS2ACredentialsNoLocalIdentity_succeeds() throws Exception { + ChannelCredentials credentials = S2AChannelCredentials.newBuilder(s2aAddress, + InsecureChannelCredentials.create()).build(); + ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); + + assertThat(doUnaryRpc(channel)).isTrue(); + } + + @Test + public void clientCommunicateUsingS2ACredentialsSucceeds_verifyStreamToS2AClosed() + throws Exception { + ObjectPool s2aChannelPool = + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource(s2aAddress, + InsecureChannelCredentials.create())); + Channel ch = 
s2aChannelPool.getObject(); + S2AStub stub = S2AStub.newInstance(S2AServiceGrpc.newStub(ch)); + ChannelCredentials credentials = + S2AChannelCredentials.newBuilder(s2aAddress, InsecureChannelCredentials.create()) + .setLocalSpiffeId("test-spiffe-id").setStub(stub).build(); + ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); + + s2aChannelPool.returnObject(ch); + assertThat(doUnaryRpc(channel)).isTrue(); + assertThat(stub.isClosed()).isTrue(); + } + + @Test + public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Exception { + ClassLoader classLoader = IntegrationTest.class.getClassLoader(); + InputStream privateKey = classLoader.getResourceAsStream("client_key.pem"); + InputStream certChain = classLoader.getResourceAsStream("client_cert.pem"); + InputStream trustBundle = classLoader.getResourceAsStream("root_cert.pem"); + ChannelCredentials s2aChannelCredentials = + TlsChannelCredentials.newBuilder() + .keyManager(certChain, privateKey) + .trustManager(trustBundle) + .build(); + + ChannelCredentials credentials = + S2AChannelCredentials.newBuilder(mtlsS2AAddress, s2aChannelCredentials) + .setLocalSpiffeId("test-spiffe-id") + .build(); + ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); + + assertThat(doUnaryRpc(channel)).isTrue(); + } + + @Test + public void clientCommunicateUsingS2ACredentials_s2AdelayStart_succeeds() throws Exception { + ChannelCredentials credentials = S2AChannelCredentials.newBuilder(s2aDelayAddress, + InsecureChannelCredentials.create()).build(); + ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); + + FutureTask rpc = new FutureTask<>(() -> doUnaryRpc(channel)); + new Thread(rpc).start(); + Thread.sleep(2000); + s2aDelayServer.start(); + assertThat(rpc.get()).isTrue(); + } + + public static boolean doUnaryRpc(ManagedChannel channel) throws InterruptedException { + try { + SimpleServiceGrpc.SimpleServiceBlockingStub stub = 
+ SimpleServiceGrpc.newBlockingStub(channel); + SimpleResponse resp = stub.unaryRpc(SimpleRequest.newBuilder() + .setRequestMessage("S2A team") + .build()); + if (!resp.getResponseMessage().equals("Hello, S2A team!")) { + logger.info( + "Received unexpected message from the Simple Service: " + resp.getResponseMessage()); + throw new RuntimeException(); + } else { + System.out.println( + "We received this message from the Simple Service: " + resp.getResponseMessage()); + return true; + } + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + } + } + + private static SslContext buildSslContext() throws SSLException { + ClassLoader classLoader = IntegrationTest.class.getClassLoader(); + InputStream privateKey = classLoader.getResourceAsStream("leaf_key_ec.pem"); + InputStream rootCert = classLoader.getResourceAsStream("root_cert_ec.pem"); + InputStream certChain = classLoader.getResourceAsStream("cert_chain_ec.pem"); + SslContextBuilder sslServerContextBuilder = + SslContextBuilder.forServer(certChain, privateKey); + SslContext sslServerContext = + GrpcSslContexts.configure(sslServerContextBuilder, SslProvider.OPENSSL) + .protocols("TLSv1.3", "TLSv1.2") + .trustManager(rootCert) + .clientAuth(ClientAuth.REQUIRE) + .build(); + + // Enable TLS resumption. This requires using the OpenSSL provider, since the JDK provider does + // not allow a server to send session tickets. + SSLSessionContext sslSessionContext = sslServerContext.sessionContext(); + if (!(sslSessionContext instanceof OpenSslSessionContext)) { + throw new SSLException("sslSessionContext does not use OpenSSL."); + } + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + // Calling {@code setTicketKeys} without specifying any keys means that the SSL libraries will + // handle the generation of the resumption master secret. 
+ openSslSessionContext.setTicketKeys(); + + return sslServerContext; + } + + public static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest request, StreamObserver observer) { + observer.onNext( + SimpleResponse.newBuilder() + .setResponseMessage("Hello, " + request.getRequestMessage() + "!") + .build()); + observer.onCompleted(); + } + } +} diff --git a/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java b/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java new file mode 100644 index 00000000000..3e6eef7f470 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java @@ -0,0 +1,136 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.grpc.ChannelCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.TlsChannelCredentials; +import java.io.InputStream; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@code S2AChannelCredentials}. 
*/ +@RunWith(JUnit4.class) +public final class S2AChannelCredentialsTest { + @Test + public void newBuilder_nullAddress_throwsException() throws Exception { + assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.newBuilder(null, + InsecureChannelCredentials.create())); + } + + @Test + public void newBuilder_emptyAddress_throwsException() throws Exception { + assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.newBuilder("", + InsecureChannelCredentials.create())); + } + + @Test + public void newBuilder_nullChannelCreds_throwsException() throws Exception { + assertThrows(NullPointerException.class, () -> S2AChannelCredentials + .newBuilder("s2a_address", null)); + } + + @Test + public void setLocalSpiffeId_nullArgument_throwsException() throws Exception { + assertThrows( + NullPointerException.class, + () -> S2AChannelCredentials.newBuilder("s2a_address", + InsecureChannelCredentials.create()).setLocalSpiffeId(null)); + } + + @Test + public void setLocalHostname_nullArgument_throwsException() throws Exception { + assertThrows( + NullPointerException.class, + () -> S2AChannelCredentials.newBuilder("s2a_address", + InsecureChannelCredentials.create()).setLocalHostname(null)); + } + + @Test + public void setLocalUid_nullArgument_throwsException() throws Exception { + assertThrows( + NullPointerException.class, + () -> S2AChannelCredentials.newBuilder("s2a_address", + InsecureChannelCredentials.create()).setLocalUid(null)); + } + + @Test + public void build_withLocalSpiffeId_succeeds() throws Exception { + assertThat( + S2AChannelCredentials.newBuilder("s2a_address", InsecureChannelCredentials.create()) + .setLocalSpiffeId("spiffe://test") + .build()) + .isNotNull(); + } + + @Test + public void build_withLocalHostname_succeeds() throws Exception { + assertThat( + S2AChannelCredentials.newBuilder("s2a_address", InsecureChannelCredentials.create()) + .setLocalHostname("local_hostname") + .build()) + .isNotNull(); + } + + @Test + 
public void build_withLocalUid_succeeds() throws Exception { + assertThat(S2AChannelCredentials.newBuilder("s2a_address", + InsecureChannelCredentials.create()).setLocalUid("local_uid").build()) + .isNotNull(); + } + + @Test + public void build_withNoLocalIdentity_succeeds() throws Exception { + assertThat(S2AChannelCredentials.newBuilder("s2a_address", + InsecureChannelCredentials.create()).build()) + .isNotNull(); + } + + @Test + public void build_withUseMtlsToS2ANoLocalIdentity_success() throws Exception { + ChannelCredentials s2aChannelCredentials = getTlsChannelCredentials(); + assertThat( + S2AChannelCredentials.newBuilder("s2a_address", s2aChannelCredentials) + .build()) + .isNotNull(); + } + + @Test + public void build_withUseMtlsToS2AWithLocalUid_success() throws Exception { + ChannelCredentials s2aChannelCredentials = getTlsChannelCredentials(); + assertThat( + S2AChannelCredentials.newBuilder("s2a_address", s2aChannelCredentials) + .setLocalUid("local_uid") + .build()) + .isNotNull(); + } + + private static ChannelCredentials getTlsChannelCredentials() throws Exception { + ClassLoader classLoader = S2AChannelCredentialsTest.class.getClassLoader(); + InputStream privateKey = classLoader.getResourceAsStream("client_key.pem"); + InputStream certChain = classLoader.getResourceAsStream("client_cert.pem"); + InputStream trustBundle = classLoader.getResourceAsStream("root_cert.pem"); + return TlsChannelCredentials.newBuilder() + .keyManager(certChain, privateKey) + .trustManager(trustBundle) + .build(); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/channel/S2AHandshakerServiceChannelTest.java b/s2a/src/test/java/io/grpc/s2a/internal/channel/S2AHandshakerServiceChannelTest.java new file mode 100644 index 00000000000..9ba3caaf99e --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/channel/S2AHandshakerServiceChannelTest.java @@ -0,0 +1,259 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.channel; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; +import io.grpc.StatusRuntimeException; +import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import java.io.InputStream; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AHandshakerServiceChannel}. 
*/ +@RunWith(JUnit4.class) +public final class S2AHandshakerServiceChannelTest { + @ClassRule public static final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private Server mtlsServer; + private Server plaintextServer; + + @Before + public void setUp() throws Exception { + mtlsServer = createMtlsServer(); + plaintextServer = createPlaintextServer(); + mtlsServer.start(); + plaintextServer.start(); + } + + /** + * Creates a {@code Resource} and verifies that it produces a {@code ChannelResource} + * instance by using its {@code toString()} method. + */ + @Test + public void getChannelResource_success() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + InsecureChannelCredentials.create()); + assertThat(resource.toString()).isEqualTo("grpc-s2a-channel"); + } + + /** Same as getChannelResource_success, but use mTLS. */ + @Test + public void getChannelResource_mtlsSuccess() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + assertThat(resource.toString()).isEqualTo("grpc-s2a-channel"); + } + + /** + * Creates two {@code Resource}s for the same target address and verifies that they are + * distinct. + */ + @Test + public void getChannelResource_twoUnEqualChannels() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + InsecureChannelCredentials.create()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + InsecureChannelCredentials.create()); + assertThat(resource).isNotEqualTo(resourceTwo); + } + + /** Same as getChannelResource_twoUnEqualChannels, but use mTLS. 
*/ + @Test + public void getChannelResource_mtlsTwoUnEqualChannels() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + assertThat(resource).isNotEqualTo(resourceTwo); + } + + /** + * Creates two {@code Resource}s for different target addresses and verifies that they are + * distinct. + */ + @Test + public void getChannelResource_twoDistinctChannels() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + InsecureChannelCredentials.create()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort() + 1, InsecureChannelCredentials.create()); + assertThat(resourceTwo).isNotEqualTo(resource); + } + + /** Same as getChannelResource_twoDistinctChannels, but use mTLS. */ + @Test + public void getChannelResource_mtlsTwoDistinctChannels() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort() + 1, getTlsChannelCredentials()); + assertThat(resourceTwo).isNotEqualTo(resource); + } + + /** + * Uses a {@code Resource} to create a channel, closes the channel, and verifies that the + * channel is closed by attempting to make a simple RPC. 
+ */ + @Test + public void close_success() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + InsecureChannelCredentials.create()); + Channel channel = resource.create(); + resource.close(channel); + StatusRuntimeException expected = + assertThrows( + StatusRuntimeException.class, + () -> + SimpleServiceGrpc.newBlockingStub(channel) + .unaryRpc(SimpleRequest.getDefaultInstance())); + assertThat(expected).hasMessageThat().isEqualTo("UNAVAILABLE: Channel shutdown invoked"); + } + + /** Same as close_success, but use mTLS. */ + @Test + public void close_mtlsSuccess() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Channel channel = resource.create(); + resource.close(channel); + StatusRuntimeException expected = + assertThrows( + StatusRuntimeException.class, + () -> + SimpleServiceGrpc.newBlockingStub(channel) + .unaryRpc(SimpleRequest.getDefaultInstance())); + assertThat(expected).hasMessageThat().isEqualTo("UNAVAILABLE: Channel shutdown invoked"); + } + + /** + * Creates and closes a {@code ManagedChannel}, creates a new channel from the same + * resource, and verifies that this second channel is useable. + */ + @Test + public void create_succeedsAfterCloseIsCalledOnce() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + InsecureChannelCredentials.create()); + Channel channelOne = resource.create(); + resource.close(channelOne); + + Channel channelTwo = resource.create(); + assertThat(channelTwo).isInstanceOf(ManagedChannel.class); + assertThat( + SimpleServiceGrpc.newBlockingStub(channelTwo) + .unaryRpc(SimpleRequest.getDefaultInstance())) + .isEqualToDefaultInstance(); + resource.close(channelTwo); + } + + /** Same as create_succeedsAfterCloseIsCalledOnce, but use mTLS. 
*/ + @Test + public void create_mtlsSucceedsAfterCloseIsCalledOnce() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Channel channelOne = resource.create(); + resource.close(channelOne); + + Channel channelTwo = resource.create(); + assertThat(channelTwo).isInstanceOf(ManagedChannel.class); + assertThat( + SimpleServiceGrpc.newBlockingStub(channelTwo) + .unaryRpc(SimpleRequest.getDefaultInstance())) + .isEqualToDefaultInstance(); + resource.close(channelTwo); + } + + private static Server createMtlsServer() throws Exception { + SimpleServiceImpl service = new SimpleServiceImpl(); + ClassLoader classLoader = S2AHandshakerServiceChannelTest.class.getClassLoader(); + InputStream serverCert = classLoader.getResourceAsStream("server_cert.pem"); + InputStream serverKey = classLoader.getResourceAsStream("server_key.pem"); + InputStream rootCert = classLoader.getResourceAsStream("root_cert.pem"); + ServerCredentials creds = + TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverKey) + .trustManager(rootCert) + .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) + .build(); + return grpcCleanup.register( + NettyServerBuilder.forPort(0, creds).addService(service).build()); + } + + private static Server createPlaintextServer() { + SimpleServiceImpl service = new SimpleServiceImpl(); + return grpcCleanup.register( + ServerBuilder.forPort(0).addService(service).build()); + } + + private static ChannelCredentials getTlsChannelCredentials() throws Exception { + ClassLoader classLoader = S2AHandshakerServiceChannelTest.class.getClassLoader(); + InputStream clientCert = classLoader.getResourceAsStream("client_cert.pem"); + InputStream clientKey = classLoader.getResourceAsStream("client_key.pem"); + InputStream rootCert = classLoader.getResourceAsStream("root_cert.pem"); + return TlsChannelCredentials.newBuilder() + .keyManager(clientCert, clientKey) + 
.trustManager(rootCert) + .build(); + } + + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest request, StreamObserver streamObserver) { + streamObserver.onNext(SimpleResponse.getDefaultInstance()); + streamObserver.onCompleted(); + } + } +} diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeS2AServer.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeS2AServer.java new file mode 100644 index 00000000000..322397c93be --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeS2AServer.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import com.google.s2a.proto.v2.S2AServiceGrpc; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.security.NoSuchAlgorithmException; +import java.security.spec.InvalidKeySpecException; +import java.util.logging.Logger; + +/** A fake S2Av2 server that should be used for testing only. 
*/ +public final class FakeS2AServer extends S2AServiceGrpc.S2AServiceImplBase { + private static final Logger logger = Logger.getLogger(FakeS2AServer.class.getName()); + + private final FakeWriter writer; + + public FakeS2AServer() throws InvalidKeySpecException, NoSuchAlgorithmException, IOException { + this.writer = new FakeWriter(); + this.writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS).initializePrivateKey(); + } + + @Override + public StreamObserver setUpSession(StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(SessionReq req) { + logger.info("Received a request from client."); + try { + responseObserver.onNext(writer.handleResponse(req)); + } catch (IOException e) { + responseObserver.onError(e); + } + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeS2AServerTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeS2AServerTest.java new file mode 100644 index 00000000000..c3155b864b3 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeS2AServerTest.java @@ -0,0 +1,300 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.SettableFuture; +import com.google.protobuf.ByteString; +import com.google.s2a.proto.v2.Ciphersuite; +import com.google.s2a.proto.v2.ConnectionSide; +import com.google.s2a.proto.v2.GetTlsConfigurationReq; +import com.google.s2a.proto.v2.GetTlsConfigurationResp; +import com.google.s2a.proto.v2.S2AServiceGrpc; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import com.google.s2a.proto.v2.TLSVersion; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainReq; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainReq.VerificationMode; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainResp; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeoutException; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link FakeS2AServer}. 
*/ +@RunWith(JUnit4.class) +public final class FakeS2AServerTest { + private static final Logger logger = Logger.getLogger(FakeS2AServerTest.class.getName()); + + private static final ImmutableList FAKE_CERT_DER_CHAIN = + ImmutableList.of(ByteString.copyFrom("fake-der-chain".getBytes(StandardCharsets.US_ASCII))); + private String serverAddress; + private Server fakeS2AServer; + + @Before + public void setUp() throws Exception { + fakeS2AServer = ServerBuilder.forPort(0).addService(new FakeS2AServer()).build(); + fakeS2AServer.start(); + serverAddress = String.format("localhost:%d", fakeS2AServer.getPort()); + } + + @After + public void tearDown() throws Exception { + fakeS2AServer.shutdown(); + fakeS2AServer.awaitTermination(10, SECONDS); + } + + @Test + public void callS2AServerOnce_getTlsConfiguration_returnsValidResult() + throws InterruptedException, + IOException, + java.util.concurrent.ExecutionException, + TimeoutException { + ExecutorService executor = Executors.newSingleThreadExecutor(); + logger.info("Client connecting to: " + serverAddress); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + SettableFuture respFuture = SettableFuture.create(); + try { + S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); + StreamObserver requestObserver = + asyncStub.setUpSession( + new StreamObserver() { + SessionResp recvResp; + @Override + public void onNext(SessionResp resp) { + recvResp = resp; + } + + @Override + public void onError(Throwable t) { + respFuture.setException(t); + } + + @Override + public void onCompleted() { + respFuture.set(recvResp); + } + }); + try { + requestObserver.onNext( + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build()); + } catch (RuntimeException e) { + // Cancel the RPC. 
+ requestObserver.onError(e); + throw e; + } + // Mark the end of requests. + requestObserver.onCompleted(); + // Wait for receiving to happen. + respFuture.get(5, SECONDS); + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + + String leafCertString = ""; + String cert2String = ""; + String cert1String = ""; + ClassLoader classLoader = FakeS2AServerTest.class.getClassLoader(); + try ( + InputStream leafCert = classLoader.getResourceAsStream("leaf_cert_ec.pem"); + InputStream cert2 = classLoader.getResourceAsStream("int_cert2_ec.pem"); + InputStream cert1 = classLoader.getResourceAsStream("int_cert1_ec.pem"); + ) { + leafCertString = FakeWriter.convertInputStreamToString(leafCert); + cert2String = FakeWriter.convertInputStreamToString(cert2); + cert1String = FakeWriter.convertInputStreamToString(cert1); + } + + SessionResp expected = + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(leafCertString) + .addCertificateChain(cert1String) + .addCertificateChain(cert2String) + .setMinTlsVersion(TLSVersion.TLS_VERSION_1_3) + .setMaxTlsVersion(TLSVersion.TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + assertThat(respFuture.get()).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void callS2AServerOnce_validatePeerCertifiate_returnsValidResult() + throws InterruptedException, java.util.concurrent.ExecutionException, TimeoutException { + ExecutorService executor = Executors.newSingleThreadExecutor(); + logger.info("Client connecting to: " + serverAddress); + 
ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + SettableFuture respFuture = SettableFuture.create(); + try { + S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); + StreamObserver requestObserver = + asyncStub.setUpSession( + new StreamObserver() { + private SessionResp recvResp; + @Override + public void onNext(SessionResp resp) { + recvResp = resp; + } + + @Override + public void onError(Throwable t) { + respFuture.setException(t); + } + + @Override + public void onCompleted() { + respFuture.set(recvResp); + } + }); + try { + requestObserver.onNext( + SessionReq.newBuilder() + .setValidatePeerCertificateChainReq( + ValidatePeerCertificateChainReq.newBuilder() + .setMode(VerificationMode.UNSPECIFIED) + .setClientPeer( + ValidatePeerCertificateChainReq.ClientPeer.newBuilder() + .addAllCertificateChain(FAKE_CERT_DER_CHAIN))) + .build()); + } catch (RuntimeException e) { + // Cancel the RPC. + requestObserver.onError(e); + throw e; + } + // Mark the end of requests. + requestObserver.onCompleted(); + // Wait for receiving to happen. 
+ respFuture.get(5, SECONDS); + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + + SessionResp expected = + SessionResp.newBuilder() + .setValidatePeerCertificateChainResp( + ValidatePeerCertificateChainResp.newBuilder() + .setValidationResult(ValidatePeerCertificateChainResp.ValidationResult.SUCCESS)) + .build(); + assertThat(respFuture.get()).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void callS2AServerRepeatedly_returnsValidResult() throws InterruptedException { + final int numberOfRequests = 10; + ExecutorService executor = Executors.newSingleThreadExecutor(); + logger.info("Client connecting to: " + serverAddress); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + + try { + S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); + CountDownLatch finishLatch = new CountDownLatch(1); + StreamObserver requestObserver = + asyncStub.setUpSession( + new StreamObserver() { + private int expectedNumberOfReplies = numberOfRequests; + + @Override + public void onNext(SessionResp reply) { + System.out.println("Received a message from the S2AService service."); + expectedNumberOfReplies -= 1; + } + + @Override + public void onError(Throwable t) { + finishLatch.countDown(); + if (expectedNumberOfReplies != 0) { + throw new RuntimeException(t); + } + } + + @Override + public void onCompleted() { + finishLatch.countDown(); + if (expectedNumberOfReplies != 0) { + throw new RuntimeException(); + } + } + }); + try { + for (int i = 0; i < numberOfRequests; i++) { + requestObserver.onNext(SessionReq.getDefaultInstance()); + } + } catch (RuntimeException e) { + // Cancel the RPC. + requestObserver.onError(e); + throw e; + } + // Mark the end of requests. + requestObserver.onCompleted(); + // Wait for receiving to happen. 
+ if (!finishLatch.await(10, SECONDS)) { + throw new RuntimeException(); + } + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + } + +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeWriter.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeWriter.java new file mode 100644 index 00000000000..0b398638f92 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeWriter.java @@ -0,0 +1,386 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.s2a.proto.v2.TLSVersion.TLS_VERSION_1_2; +import static com.google.s2a.proto.v2.TLSVersion.TLS_VERSION_1_3; + +import com.google.common.collect.ImmutableMap; +import com.google.common.io.CharStreams; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.protobuf.ByteString; +import com.google.s2a.proto.v2.Ciphersuite; +import com.google.s2a.proto.v2.ConnectionSide; +import com.google.s2a.proto.v2.GetTlsConfigurationReq; +import com.google.s2a.proto.v2.GetTlsConfigurationResp; +import com.google.s2a.proto.v2.OffloadPrivateKeyOperationReq; +import com.google.s2a.proto.v2.OffloadPrivateKeyOperationResp; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import com.google.s2a.proto.v2.SignatureAlgorithm; +import com.google.s2a.proto.v2.Status; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainReq; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainResp; +import io.grpc.stub.StreamObserver; +import io.grpc.util.CertificateUtils; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.Signature; +import java.security.spec.InvalidKeySpecException; + +/** A fake Writer Class to mock the behavior of S2A server. */ +final class FakeWriter implements StreamObserver { + /** Fake behavior of S2A service. 
*/ + enum Behavior { + OK_STATUS, + EMPTY_RESPONSE, + ERROR_STATUS, + ERROR_RESPONSE, + COMPLETE_STATUS, + BAD_TLS_VERSION_RESPONSE, + } + + enum VerificationResult { + UNSPECIFIED, + SUCCESS, + FAILURE + } + + private static final ClassLoader classLoader = FakeWriter.class.getClassLoader(); + private static final ImmutableMap + ALGORITHM_TO_SIGNATURE_INSTANCE_IDENTIFIER = + ImmutableMap.of( + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256, + "SHA256withECDSA", + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384, + "SHA384withECDSA", + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512, + "SHA512withECDSA"); + + private boolean fakeWriterClosed = false; + private Behavior behavior = Behavior.OK_STATUS; + private StreamObserver reader; + private VerificationResult verificationResult = VerificationResult.UNSPECIFIED; + private String failureReason; + private PrivateKey privateKey; + + public static String convertInputStreamToString(InputStream is) throws IOException { + return CharStreams.toString(new InputStreamReader(is, StandardCharsets.UTF_8)); + } + + @CanIgnoreReturnValue + FakeWriter setReader(StreamObserver reader) { + this.reader = reader; + return this; + } + + @CanIgnoreReturnValue + FakeWriter setBehavior(Behavior behavior) { + this.behavior = behavior; + return this; + } + + @CanIgnoreReturnValue + FakeWriter setVerificationResult(VerificationResult verificationResult) { + this.verificationResult = verificationResult; + return this; + } + + @CanIgnoreReturnValue + FakeWriter setFailureReason(String failureReason) { + this.failureReason = failureReason; + return this; + } + + @CanIgnoreReturnValue + FakeWriter initializePrivateKey() throws InvalidKeySpecException, NoSuchAlgorithmException, + IOException, FileNotFoundException, UnsupportedEncodingException { + try ( + InputStream keyInputStream = classLoader.getResourceAsStream("leaf_key_ec.pem"); + ) { + privateKey = CertificateUtils.getPrivateKey(keyInputStream); + } + return this; + } + 
+ @CanIgnoreReturnValue + FakeWriter resetPrivateKey() { + privateKey = null; + return this; + } + + void sendUnexpectedResponse() { + reader.onNext(SessionResp.getDefaultInstance()); + } + + void sendIoError() { + reader.onError(new IOException("Intended ERROR from FakeWriter.")); + } + + void sendGetTlsConfigResp() { + String leafCertString = ""; + String cert2String = ""; + String cert1String = ""; + try ( + InputStream leafCert = classLoader.getResourceAsStream("leaf_cert_ec.pem"); + InputStream cert2 = classLoader.getResourceAsStream("int_cert2_ec.pem"); + InputStream cert1 = classLoader.getResourceAsStream("int_cert1_ec.pem"); + ) { + leafCertString = FakeWriter.convertInputStreamToString(leafCert); + cert2String = FakeWriter.convertInputStreamToString(cert2); + cert1String = FakeWriter.convertInputStreamToString(cert1); + } catch (IOException e) { + reader.onError(e); + } + reader.onNext( + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(leafCertString) + .addCertificateChain(cert1String) + .addCertificateChain(cert2String) + .setMinTlsVersion(TLS_VERSION_1_3) + .setMaxTlsVersion(TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite + .CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build()); + } + + boolean isFakeWriterClosed() { + return fakeWriterClosed; + } + + @Override + public void onNext(SessionReq sessionReq) { + switch (behavior) { + case OK_STATUS: + try { + reader.onNext(handleResponse(sessionReq)); + } catch (IOException e) { + reader.onError(e); + } + break; + case EMPTY_RESPONSE: + reader.onNext(SessionResp.getDefaultInstance()); + break; + case ERROR_STATUS: + reader.onNext( + SessionResp.newBuilder() + 
.setStatus( + Status.newBuilder() + .setCode(1) + .setDetails("Intended ERROR Status from FakeWriter.")) + .build()); + break; + case ERROR_RESPONSE: + reader.onError(new S2AConnectionException("Intended ERROR from FakeWriter.")); + break; + case COMPLETE_STATUS: + reader.onCompleted(); + break; + case BAD_TLS_VERSION_RESPONSE: + String leafCertString = ""; + String cert2String = ""; + String cert1String = ""; + try ( + InputStream leafCert = classLoader.getResourceAsStream("leaf_cert_ec.pem"); + InputStream cert2 = classLoader.getResourceAsStream("int_cert2_ec.pem"); + InputStream cert1 = classLoader.getResourceAsStream("int_cert1_ec.pem"); + ) { + leafCertString = FakeWriter.convertInputStreamToString(leafCert); + cert2String = FakeWriter.convertInputStreamToString(cert2); + cert1String = FakeWriter.convertInputStreamToString(cert1); + } catch (IOException e) { + reader.onError(e); + } + reader.onNext( + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(leafCertString) + .addCertificateChain(cert1String) + .addCertificateChain(cert2String) + .setMinTlsVersion(TLS_VERSION_1_3) + .setMaxTlsVersion(TLS_VERSION_1_2))) + .build()); + break; + default: + try { + reader.onNext(handleResponse(sessionReq)); + } catch (IOException e) { + reader.onError(e); + } + } + } + + SessionResp handleResponse(SessionReq sessionReq) throws IOException { + if (sessionReq.hasGetTlsConfigurationReq()) { + return handleGetTlsConfigurationReq(sessionReq.getGetTlsConfigurationReq()); + } + + if (sessionReq.hasValidatePeerCertificateChainReq()) { + return handleValidatePeerCertificateChainReq(sessionReq.getValidatePeerCertificateChainReq()); + } + + if (sessionReq.hasOffloadPrivateKeyOperationReq()) { + return handleOffloadPrivateKeyOperationReq(sessionReq.getOffloadPrivateKeyOperationReq()); + } + + return SessionResp.newBuilder() + 
.setStatus( + Status.newBuilder().setCode(255).setDetails("No supported operation designated.")) + .build(); + } + + private SessionResp handleGetTlsConfigurationReq(GetTlsConfigurationReq req) + throws IOException { + if (!req.getConnectionSide().equals(ConnectionSide.CONNECTION_SIDE_CLIENT)) { + return SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(255) + .setDetails("No TLS configuration for the server side.")) + .build(); + } + String leafCertString = ""; + String cert2String = ""; + String cert1String = ""; + try ( + InputStream leafCert = classLoader.getResourceAsStream("leaf_cert_ec.pem"); + InputStream cert2 = classLoader.getResourceAsStream("int_cert2_ec.pem"); + InputStream cert1 = classLoader.getResourceAsStream("int_cert1_ec.pem"); + ) { + leafCertString = FakeWriter.convertInputStreamToString(leafCert); + cert2String = FakeWriter.convertInputStreamToString(cert2); + cert1String = FakeWriter.convertInputStreamToString(cert1); + } catch (IOException e) { + reader.onError(e); + } + return SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(leafCertString) + .addCertificateChain(cert1String) + .addCertificateChain(cert2String) + .setMinTlsVersion(TLS_VERSION_1_3) + .setMaxTlsVersion(TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + } + + private SessionResp handleValidatePeerCertificateChainReq(ValidatePeerCertificateChainReq req) { + if (verifyValidatePeerCertificateChainReq(req) + && verificationResult == VerificationResult.SUCCESS) { + return SessionResp.newBuilder() + .setValidatePeerCertificateChainResp( + 
ValidatePeerCertificateChainResp.newBuilder() + .setValidationResult(ValidatePeerCertificateChainResp.ValidationResult.SUCCESS)) + .build(); + } + return SessionResp.newBuilder() + .setValidatePeerCertificateChainResp( + ValidatePeerCertificateChainResp.newBuilder() + .setValidationResult( + verificationResult == VerificationResult.FAILURE + ? ValidatePeerCertificateChainResp.ValidationResult.FAILURE + : ValidatePeerCertificateChainResp.ValidationResult.UNSPECIFIED) + .setValidationDetails(failureReason)) + .build(); + } + + private boolean verifyValidatePeerCertificateChainReq(ValidatePeerCertificateChainReq req) { + if (req.getMode() != ValidatePeerCertificateChainReq.VerificationMode.UNSPECIFIED) { + return false; + } + if (req.getClientPeer().getCertificateChainCount() > 0) { + return true; + } + if (req.getServerPeer().getCertificateChainCount() > 0 + && !req.getServerPeer().getServerHostname().isEmpty()) { + return true; + } + return false; + } + + private SessionResp handleOffloadPrivateKeyOperationReq(OffloadPrivateKeyOperationReq req) { + if (privateKey == null) { + return SessionResp.newBuilder() + .setStatus(Status.newBuilder().setCode(255).setDetails("No Private Key available.")) + .build(); + } + String signatureIdentifier = + ALGORITHM_TO_SIGNATURE_INSTANCE_IDENTIFIER.get(req.getSignatureAlgorithm()); + if (signatureIdentifier == null) { + return SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(255) + .setDetails("Only ECDSA key algorithms are supported.")) + .build(); + } + + byte[] signature; + try { + Signature sig = Signature.getInstance(signatureIdentifier); + sig.initSign(privateKey); + sig.update(req.getRawBytes().toByteArray()); + signature = sig.sign(); + } catch (Exception e) { + return SessionResp.newBuilder() + .setStatus(Status.newBuilder().setCode(255).setDetails(e.getMessage())) + .build(); + } + + return SessionResp.newBuilder() + .setOffloadPrivateKeyOperationResp( + 
OffloadPrivateKeyOperationResp.newBuilder().setOutBytes(ByteString.copyFrom(signature))) + .build(); + } + + @Override + public void onError(Throwable t) { + throw new UnsupportedOperationException("onError is not supported by FakeWriter."); + } + + @Override + public void onCompleted() { + fakeWriterClosed = true; + reader.onCompleted(); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanismsTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanismsTest.java new file mode 100644 index 00000000000..c1c629366aa --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanismsTest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import com.google.common.truth.Expect; +import com.google.s2a.proto.v2.AuthenticationMechanism; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.grpc.s2a.internal.handshaker.tokenmanager.AccessTokenManager; +import io.grpc.s2a.internal.handshaker.tokenmanager.SingleTokenFetcher; +import java.util.Optional; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link GetAuthenticationMechanisms}. 
*/ +@RunWith(JUnit4.class) +public final class GetAuthenticationMechanismsTest { + @Rule public final Expect expect = Expect.create(); + private static final String TOKEN = "access_token"; + private static String originalAccessToken; + private Optional tokenManager; + + @BeforeClass + public static void setUpClass() { + originalAccessToken = SingleTokenFetcher.getAccessToken(); + // Set the token that the client will use to authenticate to the S2A. + SingleTokenFetcher.setAccessToken(TOKEN); + } + + @Before + public void setUp() { + tokenManager = AccessTokenManager.create(); + } + + @AfterClass + public static void tearDownClass() { + SingleTokenFetcher.setAccessToken(originalAccessToken); + } + + @Test + public void getAuthMechanisms_emptyIdentity_success() { + expect + .that(GetAuthenticationMechanisms.getAuthMechanism(Optional.empty(), tokenManager)) + .isEqualTo( + Optional.of(AuthenticationMechanism.newBuilder().setToken("access_token").build())); + } + + @Test + public void getAuthMechanisms_nonEmptyIdentity_success() { + S2AIdentity fakeIdentity = S2AIdentity.fromSpiffeId("fake-spiffe-id"); + expect + .that(GetAuthenticationMechanisms.getAuthMechanism(Optional.of(fakeIdentity), tokenManager)) + .isEqualTo( + Optional.of( + AuthenticationMechanism.newBuilder() + .setIdentity(fakeIdentity.getIdentity()) + .setToken("access_token") + .build())); + } +} diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/ProtoUtilTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/ProtoUtilTest.java new file mode 100644 index 00000000000..28dbf0e4d88 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/ProtoUtilTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableSet; +import com.google.common.truth.Expect; +import com.google.s2a.proto.v2.TLSVersion; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ProtoUtil}. */ +@RunWith(JUnit4.class) +public final class ProtoUtilTest { + @Rule public final Expect expect = Expect.create(); + + @Test + public void convertTlsProtocolVersion_success() { + expect + .that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_3)) + .isEqualTo("TLSv1.3"); + expect + .that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_2)) + .isEqualTo("TLSv1.2"); + expect + .that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_1)) + .isEqualTo("TLSv1.1"); + expect.that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_0)).isEqualTo("TLSv1"); + } + + @Test + public void convertTlsProtocolVersion_withUnknownTlsVersion_fails() { + IllegalArgumentException expected = + assertThrows( + IllegalArgumentException.class, + () -> ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_UNSPECIFIED)); + expect.that(expected).hasMessageThat().isEqualTo("TLS version 0 is not supported."); + } + + @Test + public void buildTlsProtocolVersionSet_success() { + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_1_0, TLSVersion.TLS_VERSION_1_3)) + .isEqualTo(ImmutableSet.of("TLSv1", "TLSv1.1", "TLSv1.2", 
"TLSv1.3")); + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_1_2, TLSVersion.TLS_VERSION_1_2)) + .isEqualTo(ImmutableSet.of("TLSv1.2")); + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_1_3, TLSVersion.TLS_VERSION_1_3)) + .isEqualTo(ImmutableSet.of("TLSv1.3")); + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_1_3, TLSVersion.TLS_VERSION_1_2)) + .isEmpty(); + } + + @Test + public void buildTlsProtocolVersionSet_failure() { + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_UNSPECIFIED, TLSVersion.TLS_VERSION_1_3)) + .isEqualTo(ImmutableSet.of("TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3")); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethodTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethodTest.java new file mode 100644 index 00000000000..8f71496cab8 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethodTest.java @@ -0,0 +1,318 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.truth.Expect; +import com.google.protobuf.ByteString; +import com.google.s2a.proto.v2.OffloadPrivateKeyOperationReq; +import com.google.s2a.proto.v2.OffloadPrivateKeyOperationResp; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import com.google.s2a.proto.v2.SignatureAlgorithm; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslPrivateKeyMethod; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.security.PublicKey; +import java.security.Signature; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Optional; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class S2APrivateKeyMethodTest { + @Rule public final Expect expect = Expect.create(); + private static final byte[] DATA_TO_SIGN = "random bytes for signing.".getBytes(UTF_8); + + private S2AStub stub; + private FakeWriter writer; + private S2APrivateKeyMethod keyMethod; + + private static PublicKey extractPublicKeyFromPem(String pem) throws Exception { + X509Certificate cert = + (X509Certificate) + CertificateFactory.getInstance("X.509") + .generateCertificate(new ByteArrayInputStream(pem.getBytes(UTF_8))); + return cert.getPublicKey(); + } + + private static boolean verifySignature( + byte[] dataToSign, byte[] signature, String signatureAlgorithm) throws Exception { + 
Signature sig = Signature.getInstance(signatureAlgorithm); + InputStream leafCert = + S2APrivateKeyMethodTest.class.getClassLoader().getResourceAsStream("leaf_cert_ec.pem"); + sig.initVerify(extractPublicKeyFromPem(FakeWriter.convertInputStreamToString( + leafCert))); + leafCert.close(); + sig.update(dataToSign); + return sig.verify(signature); + } + + @Before + public void setUp() { + // This is line is to ensure that JNI correctly links the necessary objects. Without this, we + // get `java.lang.UnsatisfiedLinkError` on + // `io.netty.internal.tcnative.NativeStaticallyReferencedJniMethods.sslSignRsaPkcsSha1()` + GrpcSslContexts.configure(SslContextBuilder.forClient()); + + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + keyMethod = S2APrivateKeyMethod.create(stub, /* localIdentity= */ Optional.empty()); + } + + @Test + public void signatureAlgorithmConversion_success() { + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA256); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA384); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA512); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384); + expect + .that( + 
S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA384); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA512); + } + + @Test + public void signatureAlgorithmConversion_unsupportedOperation() { + UnsupportedOperationException e = + assertThrows( + UnsupportedOperationException.class, + () -> S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg(-1)); + + assertThat(e).hasMessageThat().contains("Signature Algorithm -1 is not supported."); + } + + @Test + public void createOnNullStub_returnsNullPointerException() { + assertThrows( + NullPointerException.class, + () -> S2APrivateKeyMethod.create(/* stub= */ null, /* localIdentity= */ Optional.empty())); + } + + @Test + public void decrypt_unsupportedOperation() { + UnsupportedOperationException e = + assertThrows( + UnsupportedOperationException.class, + () -> keyMethod.decrypt(/* engine= */ null, DATA_TO_SIGN)); + + assertThat(e).hasMessageThat().contains("decrypt is not supported."); + } + + @Test + public void fakelocalIdentity_signWithSha256_success() throws Exception { + S2AIdentity fakeIdentity = S2AIdentity.fromSpiffeId("fake-spiffe-id"); + S2AStub mockStub = mock(S2AStub.class); + OpenSslPrivateKeyMethod keyMethodWithFakeIdentity = + S2APrivateKeyMethod.create(mockStub, Optional.of(fakeIdentity)); + SessionReq req = + SessionReq.newBuilder() + 
.setLocalIdentity(fakeIdentity.getIdentity()) + .setOffloadPrivateKeyOperationReq( + OffloadPrivateKeyOperationReq.newBuilder() + .setOperation(OffloadPrivateKeyOperationReq.PrivateKeyOperation.SIGN) + .setSignatureAlgorithm(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256) + .setRawBytes(ByteString.copyFrom(DATA_TO_SIGN))) + .build(); + byte[] expectedOutbytes = "fake out bytes".getBytes(UTF_8); + when(mockStub.send(req)) + .thenReturn( + SessionResp.newBuilder() + .setOffloadPrivateKeyOperationResp( + OffloadPrivateKeyOperationResp.newBuilder() + .setOutBytes(ByteString.copyFrom(expectedOutbytes))) + .build()); + + byte[] signature = + keyMethodWithFakeIdentity.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN); + verify(mockStub).send(req); + assertThat(signature).isEqualTo(expectedOutbytes); + } + + @Test + public void signWithSha256_success() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + byte[] signature = + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN); + + assertThat(signature).isNotEmpty(); + assertThat(verifySignature(DATA_TO_SIGN, signature, "SHA256withECDSA")).isTrue(); + } + + @Test + public void signWithSha384_success() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + byte[] signature = + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384, + DATA_TO_SIGN); + + assertThat(signature).isNotEmpty(); + assertThat(verifySignature(DATA_TO_SIGN, signature, "SHA384withECDSA")).isTrue(); + } + + @Test + public void signWithSha512_success() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + byte[] signature = + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512, + DATA_TO_SIGN); + + 
assertThat(signature).isNotEmpty(); + assertThat(verifySignature(DATA_TO_SIGN, signature, "SHA512withECDSA")).isTrue(); + } + + @Test + public void sign_noKeyAvailable() throws Exception { + writer.resetPrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN)); + + assertThat(e) + .hasMessageThat() + .contains( + "Error occurred in response from S2A, error code: 255, error message: \"No Private Key" + + " available.\"."); + } + + @Test + public void sign_algorithmNotSupported() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256, + DATA_TO_SIGN)); + + assertThat(e) + .hasMessageThat() + .contains( + "Error occurred in response from S2A, error code: 255, error message: \"Only ECDSA key" + + " algorithms are supported.\"."); + } + + @Test + public void sign_getsErrorResponse() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.ERROR_STATUS); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN)); + + assertThat(e) + .hasMessageThat() + .contains( + "Error occurred in response from S2A, error code: 1, error message: \"Intended ERROR" + + " Status from FakeWriter.\"."); + } + + @Test + public void sign_getsEmptyResponse() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.EMPTY_RESPONSE); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + 
OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN)); + + assertThat(e).hasMessageThat().contains("No valid response received from S2A."); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactoryTest.java new file mode 100644 index 00000000000..7e776f16da2 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactoryTest.java @@ -0,0 +1,259 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import com.google.common.testing.NullPointerTester; +import com.google.common.testing.NullPointerTester.Visibility; +import com.google.s2a.proto.v2.S2AServiceGrpc; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import io.grpc.Channel; +import io.grpc.InsecureChannelCredentials; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.benchmarks.Utils; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.internal.TestUtils.NoopChannelLogger; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory.S2AProtocolNegotiator; +import io.grpc.stub.StreamObserver; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http2.Http2ConnectionDecoder; +import io.netty.handler.codec.http2.Http2ConnectionEncoder; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.util.AsciiString; +import java.io.IOException; +import java.util.Optional; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AProtocolNegotiatorFactory}. 
*/ +@RunWith(JUnit4.class) +public class S2AProtocolNegotiatorFactoryTest { + private static final S2AIdentity LOCAL_IDENTITY = S2AIdentity.fromSpiffeId("local identity"); + private final ChannelHandlerContext mockChannelHandlerContext = mock(ChannelHandlerContext.class); + private GrpcHttp2ConnectionHandler fakeConnectionHandler; + private String authority; + private int port; + private Server fakeS2AServer; + private ObjectPool channelPool; + + @Before + public void setUp() throws Exception { + port = Utils.pickUnusedPort(); + fakeS2AServer = ServerBuilder.forPort(port).addService(new S2AServiceImpl()).build(); + fakeS2AServer.start(); + channelPool = new FakeChannelPool(); + authority = "localhost:" + port; + fakeConnectionHandler = FakeConnectionHandler.create(authority); + } + + @After + public void tearDown() { + fakeS2AServer.shutdown(); + } + + @Test + public void handlerRemoved_success() throws Exception { + S2AProtocolNegotiatorFactory.BufferReadsHandler handler1 = + new S2AProtocolNegotiatorFactory.BufferReadsHandler(); + S2AProtocolNegotiatorFactory.BufferReadsHandler handler2 = + new S2AProtocolNegotiatorFactory.BufferReadsHandler(); + EmbeddedChannel channel = new EmbeddedChannel(handler1, handler2); + channel.writeInbound("message1"); + channel.writeInbound("message2"); + channel.writeInbound("message3"); + assertThat(handler1.getReads()).hasSize(3); + assertThat(handler2.getReads()).isEmpty(); + channel.pipeline().remove(handler1); + assertThat(handler2.getReads()).hasSize(3); + } + + @Test + public void createProtocolNegotiatorFactory_nullArgument() throws Exception { + NullPointerTester tester = new NullPointerTester().setDefault(Optional.class, Optional.empty()); + + tester.testStaticMethods(S2AProtocolNegotiatorFactory.class, Visibility.PUBLIC); + } + + @Test + public void createProtocolNegotiator_nullArgument() throws Exception { + ObjectPool pool = + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource( + 
"localhost:8080", InsecureChannelCredentials.create())); + + NullPointerTester tester = + new NullPointerTester() + .setDefault(ObjectPool.class, pool) + .setDefault(Optional.class, Optional.empty()); + + tester.testStaticMethods(S2AProtocolNegotiator.class, Visibility.PACKAGE); + } + + @Test + public void createProtocolNegotiatorFactory_getsDefaultPort_succeeds() throws Exception { + InternalProtocolNegotiator.ClientFactory clientFactory = + S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null); + + assertThat(clientFactory.getDefaultPort()).isEqualTo(S2AProtocolNegotiatorFactory.DEFAULT_PORT); + } + + @Test + public void s2aProtocolNegotiator_getHostNameOnNull_returnsNull() throws Exception { + assertThat(S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.getHostNameFromAuthority(null)) + .isNull(); + } + + @Test + public void s2aProtocolNegotiator_getHostNameOnValidAuthority_returnsValidHostname() + throws Exception { + assertThat( + S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.getHostNameFromAuthority( + "hostname:80")) + .isEqualTo("hostname"); + } + + @Test + public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClientSide_succeeds() + throws Exception { + InternalProtocolNegotiator.ClientFactory clientFactory = + S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null); + + ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); + + assertThat(clientNegotiator).isInstanceOf(S2AProtocolNegotiator.class); + assertThat(clientNegotiator.scheme()).isEqualTo(AsciiString.of("https")); + } + + @Test + public void closeProtocolNegotiator_verifyProtocolNegotiatorIsClosedOnClientSide() + throws Exception { + InternalProtocolNegotiator.ClientFactory clientFactory = + S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null); + ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); + + clientNegotiator.close(); + + 
assertThat(((FakeChannelPool) channelPool).isChannelCached()).isFalse(); + } + + @Test + public void createChannelHandler_addHandlerToMockContext() throws Exception { + ProtocolNegotiator clientNegotiator = + S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.createForClient( + channelPool, LOCAL_IDENTITY, null); + + ChannelHandler channelHandler = clientNegotiator.newHandler(fakeConnectionHandler); + + ((ChannelDuplexHandler) channelHandler).userEventTriggered(mockChannelHandlerContext, "event"); + verify(mockChannelHandlerContext).fireUserEventTriggered("event"); + } + + /** A {@code GrpcHttp2ConnectionHandler} that does nothing. */ + private static class FakeConnectionHandler extends GrpcHttp2ConnectionHandler { + private static final Http2ConnectionDecoder DECODER = mock(Http2ConnectionDecoder.class); + private static final Http2ConnectionEncoder ENCODER = mock(Http2ConnectionEncoder.class); + private static final Http2Settings SETTINGS = new Http2Settings(); + private final String authority; + + static FakeConnectionHandler create(String authority) { + return new FakeConnectionHandler(null, DECODER, ENCODER, SETTINGS, authority); + } + + private FakeConnectionHandler( + ChannelPromise channelUnused, + Http2ConnectionDecoder decoder, + Http2ConnectionEncoder encoder, + Http2Settings initialSettings, + String authority) { + super(channelUnused, decoder, encoder, initialSettings, new NoopChannelLogger()); + this.authority = authority; + } + + @Override + public String getAuthority() { + return authority; + } + } + + /** An S2A server that handles GetTlsConfiguration request. 
*/ + private static class S2AServiceImpl extends S2AServiceGrpc.S2AServiceImplBase { + static final FakeWriter writer = new FakeWriter(); + + @Override + public StreamObserver setUpSession(StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(SessionReq req) { + try { + responseObserver.onNext(writer.handleResponse(req)); + } catch (IOException e) { + responseObserver.onError(e); + } + } + + @Override + public void onError(Throwable t) {} + + @Override + public void onCompleted() {} + }; + } + } + + private static class FakeChannelPool implements ObjectPool { + private final Channel mockChannel = mock(Channel.class); + private @Nullable Channel cachedChannel = null; + + @Override + public Channel getObject() { + if (cachedChannel == null) { + cachedChannel = mockChannel; + } + return cachedChannel; + } + + @Override + public Channel returnObject(Object object) { + assertThat(object).isSameInstanceAs(mockChannel); + cachedChannel = null; + return null; + } + + public boolean isChannelCached() { + return (cachedChannel != null); + } + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AStubTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AStubTest.java new file mode 100644 index 00000000000..2c7a7dd8405 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AStubTest.java @@ -0,0 +1,285 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.truth.Expect; +import com.google.s2a.proto.v2.Ciphersuite; +import com.google.s2a.proto.v2.ConnectionSide; +import com.google.s2a.proto.v2.GetTlsConfigurationReq; +import com.google.s2a.proto.v2.GetTlsConfigurationResp; +import com.google.s2a.proto.v2.S2AServiceGrpc; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import com.google.s2a.proto.v2.Status; +import com.google.s2a.proto.v2.TLSVersion; +import io.grpc.Channel; +import io.grpc.InsecureChannelCredentials; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.io.InputStream; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AStub}. 
*/ +@RunWith(JUnit4.class) +public class S2AStubTest { + @Rule public final Expect expect = Expect.create(); + private static final String S2A_ADDRESS = "localhost:8080"; + private S2AStub stub; + private FakeWriter writer; + + @Before + public void setUp() { + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + } + + @Test + public void send_receiveOkStatus() throws Exception { + SessionReq req = + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build(); + + SessionResp resp = stub.send(req); + + assertThat(resp.hasGetTlsConfigurationResp()).isTrue(); + assertThat(resp.getGetTlsConfigurationResp().hasClientTlsConfiguration()).isTrue(); + } + + @Test + public void send_clientTlsConfiguration_receiveOkStatus() throws Exception { + SessionReq req = + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build(); + + SessionResp resp = stub.send(req); + + String leafCertString = ""; + String cert2String = ""; + String cert1String = ""; + ClassLoader classLoader = S2AStubTest.class.getClassLoader(); + try ( + InputStream leafCert = classLoader.getResourceAsStream("leaf_cert_ec.pem"); + InputStream cert2 = classLoader.getResourceAsStream("int_cert2_ec.pem"); + InputStream cert1 = classLoader.getResourceAsStream("int_cert1_ec.pem"); + ) { + leafCertString = FakeWriter.convertInputStreamToString(leafCert); + cert2String = FakeWriter.convertInputStreamToString(cert2); + cert1String = FakeWriter.convertInputStreamToString(cert1); + } + + SessionResp expected = + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(leafCertString) + 
.addCertificateChain(cert1String) + .addCertificateChain(cert2String) + .setMinTlsVersion(TLSVersion.TLS_VERSION_1_3) + .setMaxTlsVersion(TLSVersion.TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + assertThat(resp).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void send_serverTlsConfiguration_receiveErrorStatus() throws Exception { + SessionReq req = + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_SERVER)) + .build(); + + SessionResp resp = stub.send(req); + + SessionResp expected = + SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(255) + .setDetails("No TLS configuration for the server side.")) + .build(); + assertThat(resp).isEqualTo(expected); + } + + @Test + public void send_receiveErrorStatus() throws Exception { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + + SessionResp resp = stub.send(SessionReq.getDefaultInstance()); + + SessionResp expected = + SessionResp.newBuilder() + .setStatus( + Status.newBuilder().setCode(1).setDetails("Intended ERROR Status from FakeWriter.")) + .build(); + assertThat(resp).isEqualTo(expected); + } + + @Test + public void send_receiveErrorResponse() throws InterruptedException { + writer.setBehavior(FakeWriter.Behavior.ERROR_RESPONSE); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + expect.that(expected).hasCauseThat().isInstanceOf(RuntimeException.class); + expect.that(expected).hasMessageThat().contains("Intended ERROR from FakeWriter."); + } + + @Test + public void send_receiveCompleteStatus() throws Exception { + 
writer.setBehavior(FakeWriter.Behavior.COMPLETE_STATUS); + + ConnectionClosedException expected = + assertThrows( + ConnectionClosedException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected).hasMessageThat().contains("Reading from the S2A is complete."); + } + + @Test + public void send_receiveUnexpectedResponse() throws Exception { + writer.sendIoError(); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected) + .hasMessageThat() + .contains( + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable."); + } + + @Test + public void send_receiveManyUnexpectedResponse_expectResponsesEmpty() throws Exception { + writer.sendIoError(); + writer.sendIoError(); + writer.sendIoError(); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected) + .hasMessageThat() + .contains( + "Received an unexpected response from a host at the S2A's address. 
The S2A might be" + + " unavailable."); + + assertThat(stub.getResponses()).isEmpty(); + } + + @Test + public void send_receiveDelayedResponse() throws Exception { + writer.sendGetTlsConfigResp(); + IOException expectedException = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + assertThat(expectedException) + .hasMessageThat() + .contains("Received an unexpected response from a host at the S2A's address."); + + assertThat(stub.getResponses()).isEmpty(); + } + + @Test + public void send_afterEarlyClose_receivesClosedException() throws InterruptedException { + stub.close(); + expect.that(writer.isFakeWriterClosed()).isTrue(); + + ConnectionClosedException expected = + assertThrows( + ConnectionClosedException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected).hasMessageThat().contains("Stream to the S2A is closed."); + } + + @Test + public void send_withUnavailableService_throwsDeadlineExceeded() throws Exception { + ObjectPool channelPool = + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource( + S2A_ADDRESS, InsecureChannelCredentials.create())); + S2AServiceGrpc.S2AServiceStub serviceStub = S2AServiceGrpc.newStub(channelPool.getObject()); + S2AStub newStub = S2AStub.newInstanceWithDeadline(serviceStub, 1); + + IOException expected = + assertThrows(IOException.class, () -> newStub.send(SessionReq.getDefaultInstance())); + + assertThat(expected).hasMessageThat().contains("DEADLINE_EXCEEDED"); + } + + @Test + public void send_failToWrite() throws Exception { + FailWriter failWriter = new FailWriter(); + stub = S2AStub.newInstanceForTesting(failWriter); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + expect.that(expected).hasCauseThat().isInstanceOf(S2AConnectionException.class); + expect + .that(expected) + .hasCauseThat() + .hasMessageThat() + .isEqualTo("Could not send request to S2A."); + } + + /** 
Fails whenever a write is attempted. */ + private static class FailWriter implements StreamObserver { + @Override + public void onNext(SessionReq req) { + assertThat(req).isNotNull(); + throw new S2AConnectionException("Could not send request to S2A."); + } + + @Override + public void onError(Throwable t) { + assertThat(t).isInstanceOf(S2AConnectionException.class); + } + + @Override + public void onCompleted() { + throw new UnsupportedOperationException(); + } + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2ATrustManagerTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2ATrustManagerTest.java new file mode 100644 index 00000000000..198001838aa --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2ATrustManagerTest.java @@ -0,0 +1,262 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import java.io.ByteArrayInputStream; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Base64; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class S2ATrustManagerTest { + private S2AStub stub; + private FakeWriter writer; + private static final String FAKE_HOSTNAME = "Fake-Hostname"; + private static final String CLIENT_CERT_PEM = + "MIICKjCCAc+gAwIBAgIUC2GShcVO+5Zkml+7VO3OQ+B2c7EwCgYIKoZIzj0EAwIw" + + "HzEdMBsGA1UEAwwUcm9vdGNlcnQuZXhhbXBsZS5jb20wIBcNMjMwMTI2MTk0OTUx" + + "WhgPMjA1MDA2MTMxOTQ5NTFaMB8xHTAbBgNVBAMMFGxlYWZjZXJ0LmV4YW1wbGUu" + + "Y29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEeciYZgFAZjxyzTrklCRIWpad" + + "8wkyCZQzJSf0IfNn9NKtfzL2V/blteULO0o9Da8e2Avaj+XCKfFTc7salMo/waOB" + + "5jCB4zAOBgNVHQ8BAf8EBAMCB4AwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsG" + + "AQUFBwMBMAwGA1UdEwEB/wQCMAAwYQYDVR0RBFowWIYic3BpZmZlOi8vZm9vLnBy" + + "b2QuZ29vZ2xlLmNvbS9wMS9wMoIUZm9vLnByb2Quc3BpZmZlLmdvb2eCHG1hY2hp" + + "bmUtbmFtZS5wcm9kLmdvb2dsZS5jb20wHQYDVR0OBBYEFETY6Cu/aW924nfvUrOs" + + "yXCC1hrpMB8GA1UdIwQYMBaAFJLkXGlTYKISiGd+K/Ijh4IOEpHBMAoGCCqGSM49" + + "BAMCA0kAMEYCIQCZDW472c1/4jEOHES/88X7NTqsYnLtIpTjp5PZ62z3sAIhAN1J" + + "vxvbxt9ySdFO+cW7oLBEkCwUicBhxJi5VfQeQypT"; + + @Before + public void setUp() { + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + } + + @Test + public void createForClient_withNullStub_throwsError() { + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + S2ATrustManager.createForClient( + /* stub= */ null, FAKE_HOSTNAME, 
/* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().isNull(); + } + + @Test + public void createForClient_withNullHostname_throwsError() { + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + S2ATrustManager.createForClient( + stub, /* hostname= */ null, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().isNull(); + } + + @Test + public void getAcceptedIssuers_returnsExpectedNullResult() { + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + assertThat(trustManager.getAcceptedIssuers()).isNull(); + } + + @Test + public void checkClientTrusted_withEmptyCertificateChain_throwsException() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + IllegalArgumentException expected = + assertThrows( + IllegalArgumentException.class, + () -> trustManager.checkClientTrusted(new X509Certificate[] {}, /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Certificate chain has zero certificates."); + } + + @Test + public void checkServerTrusted_withEmptyCertificateChain_throwsException() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + IllegalArgumentException expected = + assertThrows( + IllegalArgumentException.class, + () -> trustManager.checkServerTrusted(new X509Certificate[] {}, /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Certificate chain has zero certificates."); + } + + @Test + public void checkClientTrusted_getsSuccessResponse() throws CertificateException { + 
writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + // Expect no exception. + trustManager.checkClientTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkClientTrusted_withLocalIdentity_getsSuccessResponse() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient( + stub, FAKE_HOSTNAME, Optional.of(S2AIdentity.fromSpiffeId("fake-spiffe-id"))); + + // Expect no exception. + trustManager.checkClientTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkServerTrusted_getsSuccessResponse() throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + // Expect no exception. + trustManager.checkServerTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkServerTrusted_withLocalIdentity_getsSuccessResponse() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient( + stub, FAKE_HOSTNAME, Optional.of(S2AIdentity.fromSpiffeId("fake-spiffe-id"))); + + // Expect no exception. 
+ trustManager.checkServerTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkClientTrusted_getsIntendedFailureResponse() throws CertificateException { + writer + .setVerificationResult(FakeWriter.VerificationResult.FAILURE) + .setFailureReason("Intended failure."); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkClientTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Intended failure."); + } + + @Test + public void checkClientTrusted_getsIntendedFailureStatusInResponse() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkClientTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Error occurred in response from S2A"); + } + + @Test + public void checkClientTrusted_getsIntendedFailureFromServer() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_RESPONSE); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkClientTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().isEqualTo("Failed to send request to S2A."); + } + + @Test + public void checkServerTrusted_getsIntendedFailureResponse() throws CertificateException { + writer + .setVerificationResult(FakeWriter.VerificationResult.FAILURE) + .setFailureReason("Intended failure."); + S2ATrustManager trustManager = + 
S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkServerTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Intended failure."); + } + + @Test + public void checkServerTrusted_getsIntendedFailureStatusInResponse() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkServerTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Error occurred in response from S2A"); + } + + @Test + public void checkServerTrusted_getsIntendedFailureFromServer() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_RESPONSE); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkServerTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().isEqualTo("Failed to send request to S2A."); + } + + private X509Certificate[] getCerts() throws CertificateException { + byte[] decoded = Base64.getDecoder().decode(CLIENT_CERT_PEM); + return new X509Certificate[] { + (X509Certificate) + CertificateFactory.getInstance("X.509") + .generateCertificate(new ByteArrayInputStream(decoded)) + }; + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/SslContextFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/SslContextFactoryTest.java new file mode 100644 index 00000000000..17b834abf2a --- /dev/null +++ 
b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/SslContextFactoryTest.java @@ -0,0 +1,177 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.truth.Expect; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslSessionContext; +import io.netty.handler.ssl.SslContext; +import java.security.GeneralSecurityException; +import java.util.Optional; +import javax.net.ssl.SSLSessionContext; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link SslContextFactory}. 
*/ +@RunWith(JUnit4.class) +public final class SslContextFactoryTest { + @Rule public final Expect expect = Expect.create(); + private static final String FAKE_TARGET_NAME = "fake_target_name"; + private S2AStub stub; + private FakeWriter writer; + + @Before + public void setUp() { + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + } + + @Test + public void createForClient_returnsValidSslContext() throws Exception { + SslContext sslContext = + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty()); + + expect.that(sslContext).isNotNull(); + expect.that(sslContext.sessionCacheSize()).isEqualTo(1); + expect.that(sslContext.sessionTimeout()).isEqualTo(300); + expect.that(sslContext.isClient()).isTrue(); + expect.that(sslContext.applicationProtocolNegotiator().protocols()).containsExactly("h2"); + SSLSessionContext sslSessionContext = sslContext.sessionContext(); + if (sslSessionContext instanceof OpenSslSessionContext) { + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + expect.that(openSslSessionContext.isSessionCacheEnabled()).isFalse(); + } + } + + @Test + public void createForClient_withLocalIdentity_returnsValidSslContext() throws Exception { + SslContext sslContext = + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, Optional.of(S2AIdentity.fromSpiffeId("fake-spiffe-id"))); + + expect.that(sslContext).isNotNull(); + expect.that(sslContext.sessionCacheSize()).isEqualTo(1); + expect.that(sslContext.sessionTimeout()).isEqualTo(300); + expect.that(sslContext.isClient()).isTrue(); + expect.that(sslContext.applicationProtocolNegotiator().protocols()).containsExactly("h2"); + SSLSessionContext sslSessionContext = sslContext.sessionContext(); + if (sslSessionContext instanceof OpenSslSessionContext) { + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + 
expect.that(openSslSessionContext.isSessionCacheEnabled()).isFalse(); + } + } + + @Test + public void createForClient_returnsEmptyResponse_error() throws Exception { + writer.setBehavior(FakeWriter.Behavior.EMPTY_RESPONSE); + + S2AConnectionException expected = + assertThrows( + S2AConnectionException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .contains("Response from S2A server does NOT contain ClientTlsConfiguration."); + } + + @Test + public void createForClient_returnsErrorStatus_error() throws Exception { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + + S2AConnectionException expected = + assertThrows( + S2AConnectionException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().contains("Intended ERROR Status from FakeWriter."); + } + + @Test + public void createForClient_getsErrorFromServer_throwsError() throws Exception { + writer.sendIoError(); + + GeneralSecurityException expected = + assertThrows( + GeneralSecurityException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .contains("Failed to get client TLS configuration from S2A."); + } + + @Test + public void createForClient_getsBadTlsVersionsFromServer_throwsError() throws Exception { + writer.setBehavior(FakeWriter.Behavior.BAD_TLS_VERSION_RESPONSE); + + S2AConnectionException expected = + assertThrows( + S2AConnectionException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .contains("Set of TLS versions received from S2A server is empty or not supported."); + } + + @Test + public void createForClient_nullStub_throwsError() throws 
Exception { + writer.sendUnexpectedResponse(); + + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + SslContextFactory.createForClient( + /* stub= */ null, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().isEqualTo("stub should not be null."); + } + + @Test + public void createForClient_nullTargetName_throwsError() throws Exception { + writer.sendUnexpectedResponse(); + + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + SslContextFactory.createForClient( + stub, /* targetName= */ null, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .isEqualTo("targetName should not be null on client side."); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java new file mode 100644 index 00000000000..9fd33fe9070 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java @@ -0,0 +1,80 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.s2a.internal.handshaker.tokenmanager; + +import static com.google.common.truth.Truth.assertThat; + +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import java.util.Optional; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SingleTokenAccessTokenManagerTest { + private static final S2AIdentity IDENTITY = S2AIdentity.fromSpiffeId("spiffe_id"); + private static final String TOKEN = "token"; + + private String originalAccessToken; + + @Before + public void setUp() { + originalAccessToken = SingleTokenFetcher.getAccessToken(); + SingleTokenFetcher.setAccessToken(null); + } + + @After + public void tearDown() { + SingleTokenFetcher.setAccessToken(originalAccessToken); + } + + @Test + public void getDefaultToken_success() throws Exception { + SingleTokenFetcher.setAccessToken(TOKEN); + Optional manager = AccessTokenManager.create(); + assertThat(manager).isPresent(); + assertThat(manager.get().getDefaultToken()).isEqualTo(TOKEN); + } + + @Test + public void getToken_success() throws Exception { + SingleTokenFetcher.setAccessToken(TOKEN); + Optional manager = AccessTokenManager.create(); + assertThat(manager).isPresent(); + assertThat(manager.get().getToken(IDENTITY)).isEqualTo(TOKEN); + } + + @Test + public void getToken_noEnvironmentVariable() throws Exception { + assertThat(SingleTokenFetcher.create()).isEmpty(); + } + + @Test + public void create_success() throws Exception { + SingleTokenFetcher.setAccessToken(TOKEN); + Optional manager = AccessTokenManager.create(); + assertThat(manager).isPresent(); + assertThat(manager.get().getToken(IDENTITY)).isEqualTo(TOKEN); + } + + @Test + public void create_noEnvironmentVariable() throws Exception { + assertThat(AccessTokenManager.create()).isEmpty(); + } +} diff --git a/s2a/src/test/resources/README.md b/s2a/src/test/resources/README.md new file mode 100644 
index 00000000000..2250ffb1dec --- /dev/null +++ b/s2a/src/test/resources/README.md @@ -0,0 +1,69 @@ +# Generating certificates and keys for testing mTLS-S2A + +Content from: https://github.com/google/s2a-go/blob/main/testdata/README.md + +Create root CA + +``` +openssl req -x509 -sha256 -days 7305 -newkey rsa:2048 -keyout root_key.pem -out +root_cert.pem +``` + +Generate private keys for server and client + +``` +openssl genrsa -out server_key.pem 2048 +openssl genrsa -out client_key.pem 2048 +``` + +Generate CSRs for server and client (set Common Name to localhost, leave all +other fields blank) + +``` +openssl req -key server_key.pem -new -out server.csr -config config.cnf +openssl req -key client_key.pem -new -out client.csr -config config.cnf +``` + +Sign CSRs for server and client + +``` +openssl x509 -req -CA root_cert.pem -CAkey root_key.pem -in server.csr -out server_cert.pem -days 7305 -extfile config.cnf -extensions req_ext +openssl x509 -req -CA root_cert.pem -CAkey root_key.pem -in client.csr -out client_cert.pem -days 7305 +``` + +Generate self-signed ECDSA root cert + +``` +openssl ecparam -name prime256v1 -genkey -noout -out temp.pem +openssl pkcs8 -topk8 -in temp.pem -out root_key_ec.pem -nocrypt +rm temp.pem +openssl req -x509 -days 7305 -new -key root_key_ec.pem -nodes -out root_cert_ec.pem -config root_ec.cnf -extensions 'v3_req' +``` + +Generate a chain of ECDSA certs + +``` +openssl ecparam -name prime256v1 -genkey -noout -out temp.pem +openssl pkcs8 -topk8 -in temp.pem -out int_key2_ec.pem -nocrypt +rm temp.pem +openssl req -key int_key2_ec.pem -new -out temp.csr -config int_cert2.cnf +openssl x509 -req -days 7305 -in temp.csr -CA root_cert_ec.pem -CAkey root_key_ec.pem -CAcreateserial -out int_cert2_ec.pem -extfile int_cert2.cnf -extensions 'v3_req' + + +openssl ecparam -name prime256v1 -genkey -noout -out temp.pem +openssl pkcs8 -topk8 -in temp.pem -out int_key1_ec.pem -nocrypt +rm temp.pem +openssl req -key int_key1_ec.pem -new -out 
temp.csr -config int_cert1.cnf +openssl x509 -req -days 7305 -in temp.csr -CA int_cert2_ec.pem -CAkey int_key2_ec.pem -CAcreateserial -out int_cert1_ec.pem -extfile int_cert1.cnf -extensions 'v3_req' + + +openssl ecparam -name prime256v1 -genkey -noout -out temp.pem +openssl pkcs8 -topk8 -in temp.pem -out leaf_key_ec.pem -nocrypt +rm temp.pem +openssl req -key leaf_key_ec.pem -new -out temp.csr -config leaf.cnf +openssl x509 -req -days 7305 -in temp.csr -CA int_cert1_ec.pem -CAkey int_key1_ec.pem -CAcreateserial -out leaf_cert_ec.pem -extfile leaf.cnf -extensions 'v3_req' +``` + +``` +cat leaf_cert_ec.pem int_cert1_ec.pem int_cert2_ec.pem > cert_chain_ec.pem +``` \ No newline at end of file diff --git a/s2a/src/test/resources/cert_chain_ec.pem b/s2a/src/test/resources/cert_chain_ec.pem new file mode 100644 index 00000000000..a249904286c --- /dev/null +++ b/s2a/src/test/resources/cert_chain_ec.pem @@ -0,0 +1,39 @@ +-----BEGIN CERTIFICATE----- +MIIB6jCCAZCgAwIBAgIUA98F2JkYZAyz9BdIkBK3P8Df7OUwCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFaW50MU8xDzANBgNVBAsMBmludDFPVTEPMA0GA1UEAwwGaW50 +MUNOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +bGVhZk8xDzANBgNVBAsMBmxlYWZPVTEPMA0GA1UEAwwGbGVhZkNOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAEtpTTzt2VDTP6gO4uUIpg8sB63Ff4T4YPMoIGrrn3 +tU3f9j0Ysa5/xblM0LkwRImcrKKchYDiNm1wHkWo+qDImaOBgzCBgDAOBgNVHQ8B +Af8EBAMCB4AwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMAwGA1Ud +EwEB/wQCMAAwHQYDVR0OBBYEFGzFBt/E6vDJRcH+Izy4MQ9AHycqMB8GA1UdIwQY +MBaAFBYs72Jv682/xzG3Tm8hItIFis//MAoGCCqGSM49BAMCA0gAMEUCIHUcqPTB +mQ4kXE0WoOUC8ZmzvthvfKjCNe0YogcjZgwWAiEAvapmWoQIO4qie25Ae9sYRCPq +5xAHztAquk5HLfwabow= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIB8TCCAZagAwIBAgIUEXwpznJIlU+ELO7Qgb4UUGpfbj8wCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFaW50Mk8xDzANBgNVBAsMBmludDJPVTEPMA0GA1UEAwwGaW50 +MkNOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +aW50MU8xDzANBgNVBAsMBmludDFPVTEPMA0GA1UEAwwGaW50MUNOMFkwEwYHKoZI 
+zj0CAQYIKoZIzj0DAQcDQgAEoenicrtL6ezEW2yLSXADscDJQ/fdbr+vJEU/aieV +wA2EnPbrdpvQZaz+pXtuZzBLZY50XI9y33E+/PvBFtZob6OBiTCBhjAOBgNVHQ8B +Af8EBAMCAQYwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIGA1Ud +EwEB/wQIMAYBAf8CAQEwHQYDVR0OBBYEFBYs72Jv682/xzG3Tm8hItIFis//MB8G +A1UdIwQYMBaAFPhN6eGgVc36Kc50rREZhMdBIkgGMAoGCCqGSM49BAMCA0kAMEYC +IQDiPcbihg1iDi0m9CUn96IbWOTh1X75RfVJYcR3Q5T78AIhAK/fxZauDeWPzk2r +2/ohCQOZFHtAi9VRpr/TqNi3SaYt +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIB8DCCAZagAwIBAgIUNOH4wQEoKHvaQ9Xgd36vh5TnhfUwCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFcm9vdE8xDzANBgNVBAsMBnJvb3RPVTEPMA0GA1UEAwwGcm9v +dENOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +aW50Mk8xDzANBgNVBAsMBmludDJPVTEPMA0GA1UEAwwGaW50MkNOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAE44B/G4pzAvLpIUaPp8XNRtXuw8jeLgE40NjQMuqq +3jNs6ID/fv/jiRggLMXL3Tii1CisM4BRjg56/Owky1Fyv6OBiTCBhjAOBgNVHQ8B +Af8EBAMCAQYwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIGA1Ud +EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFPhN6eGgVc36Kc50rREZhMdBIkgGMB8G +A1UdIwQYMBaAFNHNBlllqi9koRtf7EBHjRMwVgWsMAoGCCqGSM49BAMCA0gAMEUC +IBd4bvqVeYSSUEGF1wB0KlYxn1L0Ub/LjgIUUQFAEwahAiEAgeArX63bnlI7u3dq +v/FGilvcLP3P3AvRozpHJiIZ860= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/client_cert.pem b/s2a/src/test/resources/client_cert.pem new file mode 100644 index 00000000000..837f8bb5019 --- /dev/null +++ b/s2a/src/test/resources/client_cert.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDPTCCAiWgAwIBAgIUaarddwSWeE4jDC9kwxEr446ehqUwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTI0MTAwMTIxNTk1NFoXDTQ0MTAwMTIxNTk1NFowFDESMBAGA1UEAwwJbG9jYWxo +b3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAxlNsldt7yAU4KRuS +2D2/FjNIE1US5olBm4HteTr++41WaELZJqNLRPPp052jEQU3aKSYNGZvUUO6buu7 +eFpz2SBNUVMyvmzzocjVAyyf4NQvDazYHWOb+/YCeUppTRWriz4V5sn47qJTQ8cd 
+CGrTFeLHxUjx4nh/OiqVXP/KnF3EqPEuqph0ky7+GirnJgPRe+C5ERuGkJye8dmP +yWGA2lSS6MeDe7JZTAMi08bAn7BuNpeBkOzz1msGGI9PnUanUs7GOPWTDdcQAVY8 +KMvHCuGaNMGpb4rOR2mm8LlbAbpTPz8Pkw4QtMCLkgsrz2CzXpVwnLsU7nDXJAIO +B155lQIDAQABo0IwQDAdBgNVHQ4EFgQUSZEyIHLzkIw7AwkBaUjYfIrGVR4wHwYD +VR0jBBgwFoAUcq3dtxAVA410YWyM0B4e+4umbiwwDQYJKoZIhvcNAQELBQADggEB +AAz0bZ4ayrZLhA45xn0yvdpdqiCtiWikCRtxgE7VXHg/ziZJVMpBpAhbIGO5tIyd +lttnRXHwz5DUwKiba4/bCEFe229BshQEql5qaqcbGbFfSly11WeqqnwR1N7c8Gpv +pD9sVrx22seN0rTUk87MY/S7mzCxHqAx35zm/LTW3pWcgCTMKFHy4Gt4mpTnXkNA +WkhP2OhW5RLiu6Whi0BEdb2TGG1+ctamgijKXb+gJeef5ehlHXG8eU862KF5UlEA +NeQKBm/PpQxOMe0NdpatjN8QRoczku0Itiodng+OZ1o+2iSNG988uFRb3CUSnjtE +R/HL6ULAFzo59EpIYxruU/w= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/client_key.pem b/s2a/src/test/resources/client_key.pem new file mode 100644 index 00000000000..38b93eb65c4 --- /dev/null +++ b/s2a/src/test/resources/client_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDGU2yV23vIBTgp +G5LYPb8WM0gTVRLmiUGbge15Ov77jVZoQtkmo0tE8+nTnaMRBTdopJg0Zm9RQ7pu +67t4WnPZIE1RUzK+bPOhyNUDLJ/g1C8NrNgdY5v79gJ5SmlNFauLPhXmyfjuolND +xx0IatMV4sfFSPHieH86KpVc/8qcXcSo8S6qmHSTLv4aKucmA9F74LkRG4aQnJ7x +2Y/JYYDaVJLox4N7sllMAyLTxsCfsG42l4GQ7PPWawYYj0+dRqdSzsY49ZMN1xAB +Vjwoy8cK4Zo0walvis5HaabwuVsBulM/Pw+TDhC0wIuSCyvPYLNelXCcuxTucNck +Ag4HXnmVAgMBAAECggEAKuW9jXaBgiS63o1jyFkmvWcPNntG0M2sfrXuRzQfFgse +vwOCk8xrSflWQNsOe+58ayp6746ekl3LdBWSIbiy6SqG/sm3pp/LXNmjVYHv/QH4 +QYV643R5t1ihdVnGiBFhXwdpVleme/tpdjYZzgnJKak5W69o/nrgzhSK5ShAy2xM +j0XXbgdqG+4JxPb5BZmjHHfXAXUfgSORMdfArkbgFBRc9wL/6JVTXjeAMy5WX9qe +5UQsSOYkwc9P2snifC/jdIhjHQOkkx59O0FgukJEFZPoagVG1duWQbnNDr7QVHCJ +jV6dg9tIT4SXD3uPSPbgNGlRUseIakCzrhHARJuA2wKBgQD/h8zoh0KaqKyViCYw +XKOFpm1pAFnp2GiDOblxNubNFAXEWnC+FlkvO/z1s0zVuYELUqfxcYMSXJFEVelK +rfjZtoC5oxqWGqLo9iCj7pa8t+ipulYcLt2SWc7eZPD4T4lzeEf1Qz77aKcz34sa +dv9lzQkDvhR/Mv1VeEGFHiq2VwKBgQDGsLcTGH5Yxs//LRSY8TigBkQEDrH5NvXu 
+2jtAzZhy1Yhsoa5eiZkhnnzM6+n05ovfZLcy6s7dnwP1Y+C79vs+DKMBsodtDG5z +YpsB0VrXYa6P6pCqkcz0Bz9xdo5sOhAK3AKnX6jd29XBDdeYsw/lxHLG24wProTD +cCYFqtaj8wKBgQCaqKT68DL9zK14a8lBaDCIyexaqx3AjXzkP+Hfhi03XrEG4P5v +7rLYBeTbCUSt7vMN2V9QoTWFvYUm6SCkVJvTmcRblz6WL1T+z0l+LwAJBP7LC77m +m+77j2PH8yxt/iXhP6G97o+GNxdMLDbTM8bs5KZaH4fkXQY73uc5HMMZTQKBgEZS +7blYhf+t/ph2wD+RwVUCYrh86wkmJs2veCFro3WhlnO8lhbn5Mc9bTaqmVgQ8ZjT +8POYoDdYvPHxs+1TcYF4v4kuQziZmc5FLE/sZZauADb38tQsXrpQhmgGakpsEpmF +XXsYJJDB6lo2KATn+8x7R5SSyHQUdPEnlI2U9ft5AoGBAJw0NJiM1EzRS8xq0DmO +AvQaPjo01o2hH6wghws8gDQwrj0eHraHgVi7zo0VkaHJbO7ahKPudset3N7owJhA +CUAPPRtv5wn0amAyNz77f1dz4Gys3AkcchflqhbEaQpzKYx4kX0adclur4WJ/DVm +P7DI977SHCVB4FVMbXMEkBjN +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/config.cnf b/s2a/src/test/resources/config.cnf new file mode 100644 index 00000000000..5f9a7710e92 --- /dev/null +++ b/s2a/src/test/resources/config.cnf @@ -0,0 +1,17 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = req_ext + +[req_distinguished_name] +countryName = Country Name (2 letter code) +stateOrProvinceName = State or Province Name (full name) +localityName = Locality Name (eg, city) +organizationalUnitName = Organizational Unit Name (eg, section) +commonName = Common Name (eg, your name or your server\'s hostname) +emailAddress = Email Address + +[req_ext] +subjectAltName = @alt_names + +[alt_names] +IP.1 = ::1 \ No newline at end of file diff --git a/s2a/src/test/resources/int_cert1_.cnf b/s2a/src/test/resources/int_cert1_.cnf new file mode 100644 index 00000000000..ba5a0f66a5e --- /dev/null +++ b/s2a/src/test/resources/int_cert1_.cnf @@ -0,0 +1,14 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = v3_req +prompt = no + +[req_distinguished_name] +O = int1O +OU = int1OU +CN = int1CN + +[v3_req] +keyUsage = critical, keyCertSign, cRLSign +extendedKeyUsage = critical, clientAuth, serverAuth +basicConstraints = critical, CA:true, pathlen: 1 \ No 
newline at end of file diff --git a/s2a/src/test/resources/int_cert1_ec.pem b/s2a/src/test/resources/int_cert1_ec.pem new file mode 100644 index 00000000000..de83c2aba79 --- /dev/null +++ b/s2a/src/test/resources/int_cert1_ec.pem @@ -0,0 +1,13 @@ +-----BEGIN CERTIFICATE----- +MIIB8TCCAZagAwIBAgIUEXwpznJIlU+ELO7Qgb4UUGpfbj8wCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFaW50Mk8xDzANBgNVBAsMBmludDJPVTEPMA0GA1UEAwwGaW50 +MkNOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +aW50MU8xDzANBgNVBAsMBmludDFPVTEPMA0GA1UEAwwGaW50MUNOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAEoenicrtL6ezEW2yLSXADscDJQ/fdbr+vJEU/aieV +wA2EnPbrdpvQZaz+pXtuZzBLZY50XI9y33E+/PvBFtZob6OBiTCBhjAOBgNVHQ8B +Af8EBAMCAQYwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIGA1Ud +EwEB/wQIMAYBAf8CAQEwHQYDVR0OBBYEFBYs72Jv682/xzG3Tm8hItIFis//MB8G +A1UdIwQYMBaAFPhN6eGgVc36Kc50rREZhMdBIkgGMAoGCCqGSM49BAMCA0kAMEYC +IQDiPcbihg1iDi0m9CUn96IbWOTh1X75RfVJYcR3Q5T78AIhAK/fxZauDeWPzk2r +2/ohCQOZFHtAi9VRpr/TqNi3SaYt +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/int_cert2.cnf b/s2a/src/test/resources/int_cert2.cnf new file mode 100644 index 00000000000..f48524effb2 --- /dev/null +++ b/s2a/src/test/resources/int_cert2.cnf @@ -0,0 +1,14 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = v3_req +prompt = no + +[req_distinguished_name] +O = int2O +OU = int2OU +CN = int2CN + +[v3_req] +keyUsage = critical, keyCertSign, cRLSign +extendedKeyUsage = critical, clientAuth, serverAuth +basicConstraints = critical, CA:true, pathlen: 2 \ No newline at end of file diff --git a/s2a/src/test/resources/int_cert2_ec.pem b/s2a/src/test/resources/int_cert2_ec.pem new file mode 100644 index 00000000000..4f502fda808 --- /dev/null +++ b/s2a/src/test/resources/int_cert2_ec.pem @@ -0,0 +1,13 @@ +-----BEGIN CERTIFICATE----- +MIIB8DCCAZagAwIBAgIUNOH4wQEoKHvaQ9Xgd36vh5TnhfUwCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFcm9vdE8xDzANBgNVBAsMBnJvb3RPVTEPMA0GA1UEAwwGcm9v 
+dENOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +aW50Mk8xDzANBgNVBAsMBmludDJPVTEPMA0GA1UEAwwGaW50MkNOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAE44B/G4pzAvLpIUaPp8XNRtXuw8jeLgE40NjQMuqq +3jNs6ID/fv/jiRggLMXL3Tii1CisM4BRjg56/Owky1Fyv6OBiTCBhjAOBgNVHQ8B +Af8EBAMCAQYwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIGA1Ud +EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFPhN6eGgVc36Kc50rREZhMdBIkgGMB8G +A1UdIwQYMBaAFNHNBlllqi9koRtf7EBHjRMwVgWsMAoGCCqGSM49BAMCA0gAMEUC +IBd4bvqVeYSSUEGF1wB0KlYxn1L0Ub/LjgIUUQFAEwahAiEAgeArX63bnlI7u3dq +v/FGilvcLP3P3AvRozpHJiIZ860= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/int_key1_ec.pem b/s2a/src/test/resources/int_key1_ec.pem new file mode 100644 index 00000000000..909c119b60c --- /dev/null +++ b/s2a/src/test/resources/int_key1_ec.pem @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgnYGMzs4siZ7Fy3mI +rmsqBdP6We4Zt+ndtOYEGaZDj06hRANCAASh6eJyu0vp7MRbbItJcAOxwMlD991u +v68kRT9qJ5XADYSc9ut2m9BlrP6le25nMEtljnRcj3LfcT78+8EW1mhv +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/int_key2_ec.pem b/s2a/src/test/resources/int_key2_ec.pem new file mode 100644 index 00000000000..520300d2560 --- /dev/null +++ b/s2a/src/test/resources/int_key2_ec.pem @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgzLSoAcENXIiQfBS7 +meBDCohT1rofhWSfD0m55qi8V3WhRANCAATjgH8binMC8ukhRo+nxc1G1e7DyN4u +ATjQ2NAy6qreM2zogP9+/+OJGCAsxcvdOKLUKKwzgFGODnr87CTLUXK/ +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/leaf.cnf b/s2a/src/test/resources/leaf.cnf new file mode 100644 index 00000000000..c21cee5568f --- /dev/null +++ b/s2a/src/test/resources/leaf.cnf @@ -0,0 +1,14 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = v3_req +prompt = no + +[req_distinguished_name] +O = leafO +OU = leafOU +CN = leafCN + +[v3_req] +keyUsage = critical, 
digitalSignature +extendedKeyUsage = critical, clientAuth, serverAuth +basicConstraints = critical, CA:false \ No newline at end of file diff --git a/s2a/src/test/resources/leaf_cert_ec.pem b/s2a/src/test/resources/leaf_cert_ec.pem new file mode 100644 index 00000000000..ca48b821f60 --- /dev/null +++ b/s2a/src/test/resources/leaf_cert_ec.pem @@ -0,0 +1,13 @@ +-----BEGIN CERTIFICATE----- +MIIB6jCCAZCgAwIBAgIUA98F2JkYZAyz9BdIkBK3P8Df7OUwCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFaW50MU8xDzANBgNVBAsMBmludDFPVTEPMA0GA1UEAwwGaW50 +MUNOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +bGVhZk8xDzANBgNVBAsMBmxlYWZPVTEPMA0GA1UEAwwGbGVhZkNOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAEtpTTzt2VDTP6gO4uUIpg8sB63Ff4T4YPMoIGrrn3 +tU3f9j0Ysa5/xblM0LkwRImcrKKchYDiNm1wHkWo+qDImaOBgzCBgDAOBgNVHQ8B +Af8EBAMCB4AwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMAwGA1Ud +EwEB/wQCMAAwHQYDVR0OBBYEFGzFBt/E6vDJRcH+Izy4MQ9AHycqMB8GA1UdIwQY +MBaAFBYs72Jv682/xzG3Tm8hItIFis//MAoGCCqGSM49BAMCA0gAMEUCIHUcqPTB +mQ4kXE0WoOUC8ZmzvthvfKjCNe0YogcjZgwWAiEAvapmWoQIO4qie25Ae9sYRCPq +5xAHztAquk5HLfwabow= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/leaf_key_ec.pem b/s2a/src/test/resources/leaf_key_ec.pem new file mode 100644 index 00000000000..b92b90ba1da --- /dev/null +++ b/s2a/src/test/resources/leaf_key_ec.pem @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgkvnGZBh3uIYfZiau +/0qN0YcQXlwwVVUh8EybjvKUlX2hRANCAAS2lNPO3ZUNM/qA7i5QimDywHrcV/hP +hg8yggauufe1Td/2PRixrn/FuUzQuTBEiZysopyFgOI2bXAeRaj6oMiZ +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/root_cert.pem b/s2a/src/test/resources/root_cert.pem new file mode 100644 index 00000000000..ccd0a46bc23 --- /dev/null +++ b/s2a/src/test/resources/root_cert.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUWemeXZdfqcqkP8/Eyj74oTJtoNQwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM 
+GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTI0MTAwMTIxNTkxMVoXDTQ0MTAwMTIxNTkxMVowWTELMAkGA1UEBhMCQVUxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAt3A04hy5lljv86Nu0LLQZ2hA+fcImHjt1p1Mxgcta/5oxfVLcerE +ZH+DAQLDtWzp9Up/vI57MM419GIL8Iszk7hnZRS/HWJ+2jewZJtz4i/g15dLr6+1 +uabMdPOWos60BwcLMxKEe6lJO1mV4z9d4NH4mAuMIHyM+ty0Klp9MfeDJtYEh0+z +AxJUHCixDTsnKJro7My7A3ZT7bvaMfXxS7XN6qlRgBfiCmXo/GKTFfmfBW/EZGkG +XOCxE2D79wYNhC41Q/ix0kwjEeOj2vgGFoiyblSdHdzvRXzsoQTEiZSM8lJDR2IT +ZbpgbBlknMU6efNWlS8P5damB9ZWXg3x4wIDAQABo1MwUTAdBgNVHQ4EFgQUcq3d +txAVA410YWyM0B4e+4umbiwwHwYDVR0jBBgwFoAUcq3dtxAVA410YWyM0B4e+4um +biwwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEApZvaI9y7vjX/ +RRdvwf2Db9KlTE9nuVQ3AsrmG9Ml0p2X6U5aTetxdYBo2PuaaYHheF03JOH8zjpL +UfFzvbi52DPbfFAaDw/6NIAenXlg492leNvUFNjGGRyJO9R5/aDfv40/fT3Em5G5 +DnR8SeGQ9tI1t6xBBT+d+/MilSiEKVu8IIF/p0SwvEyR4pKo6wFVZR0ZiIj2v/FZ +P5Qk0Xhb+slpmaR3Wtx/mPl9Wb3kpPD4CAwhWDqFkKJql9/n9FvMjdwlCQKQGB26 +ZDXY3C0UTdktK5biNWRgAUVJEWBX6Q2amrxQHIn2d9RJ8uxCME/KBAntK+VxZE78 +w0JOvQ4Dpw== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/root_cert_ec.pem b/s2a/src/test/resources/root_cert_ec.pem new file mode 100644 index 00000000000..3d20dcfe83c --- /dev/null +++ b/s2a/src/test/resources/root_cert_ec.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBxzCCAW2gAwIBAgIUN+H7Td9dhyvMrrzZhanevAfCN34wCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFcm9vdE8xDzANBgNVBAsMBnJvb3RPVTEPMA0GA1UEAwwGcm9v +dENOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +cm9vdE8xDzANBgNVBAsMBnJvb3RPVTEPMA0GA1UEAwwGcm9vdENOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAEGnS2gVv6Bs0GtuUAOebR9E0fqaj3zi9mD97B/dgi +MLENhtVPJQzeePv6Ccap+73O0BINRNOl8tlHX0YaXDeEHKNhMF8wDgYDVR0PAQH/ +BAQDAgEGMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAPBgNVHRMBAf8E +BTADAQH/MB0GA1UdDgQWBBTRzQZZZaovZKEbX+xAR40TMFYFrDAKBggqhkjOPQQD 
+AgNIADBFAiEAgnIyLs7FsZNsJjFgYzlaut4h23RxrpUYVCVZt/+x1Q0CIG3U6WGz +YaEyKoCtBHH9cAy76+pP/NU2f7/QuHU9Vymd +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/root_ec.cnf b/s2a/src/test/resources/root_ec.cnf new file mode 100644 index 00000000000..d736865c831 --- /dev/null +++ b/s2a/src/test/resources/root_ec.cnf @@ -0,0 +1,14 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = v3_req +prompt = no + +[req_distinguished_name] +O = rootO +OU = rootOU +CN = rootCN + +[v3_req] +keyUsage = critical, keyCertSign, cRLSign +extendedKeyUsage = serverAuth, clientAuth +basicConstraints = critical, CA:true \ No newline at end of file diff --git a/s2a/src/test/resources/root_key.pem b/s2a/src/test/resources/root_key.pem new file mode 100644 index 00000000000..34d0ffa61eb --- /dev/null +++ b/s2a/src/test/resources/root_key.pem @@ -0,0 +1,30 @@ +-----BEGIN ENCRYPTED PRIVATE KEY----- +MIIFJDBWBgkqhkiG9w0BBQ0wSTAxBgkqhkiG9w0BBQwwJAQQJXNe391O3gaNbKLw +o60XrQICCAAwDAYIKoZIhvcNAgkFADAUBggqhkiG9w0DBwQI4pf69+BBF8IEggTI +JuQ3p67U9k/NWMuYXaR9a6lv24YZ1qR6ieL5B6keCaCDVoQMb5V22O0vBqCVePgr +EG0yWIeeAsARMzAxE7Lnil6abSe7tij+LjEI9F7mV/1QSFt03PLVI+e7OcKNI+Nr +6vISEi8CaddekP8JDRhPMpgdWderZvogo3REpJ8GNIUddQzu1e3ZgDtOPquqcgqb +MH/HuPE3vjj4/l6ZpX+6DZKIvzjwtBQ4PMzSWLumzmYLItd3kz7UryN+9hKluSZp +D2KB24aUIQFbDxe2DMTi5c0QIiyzjwkv081ecNJOy2gYX3uiucr8/Ax3o21RNZtI +oKCmSPVEfYdrkdfkwuSOioVTbWBZBcSZo3L2bmCkSXTuheGurEw/TtQWXBgew0Bn +UQjEJgZy96PVsQeu3t+NRCacARQi4vfv7PVHlQW8fcfcC6CeNw7VIZ8aS7supqym +RJxzMY9ZnLwO9cgybXLYgosVZnvI7nOokJPfO1+KqBK01C1Sgc3tg8czKhRuztHu +qDO0GCZ7l+9/ku/WIy/5NiatNvRo5dMAOGxsSrjI9a7+EmenoIfd8/KREVX19D+R +gZRALVATHq83rF6BdsyTwya1QUr/J24EIlkOc4HbCBm5WxA2ZjNdDBZ+KhivYaS7 +l1qrbkFOhmBD9kYRbseBrxlzKUWJMGhOpw3xebut3HngLqyezLcjsXQuF3Iau5Hl +9QFcmSdLj2ZlNlQvmfNJX/r6a/K2LigruXCbvHWMqVsHd7XZdWJ/8wjm2AL97iON +mYFLP+ScfYom9qrF41jNkUKZiLk/ppvSHyWBAqbze+R9Zfpcf8ArCwuAL/JlEMzv +YkBv1DWKfzJpZHYX695MxrpS3C8m0IyXNxktBL3KTVvwZaIhSNBlNS3fdb9m8toR 
+Tz/LS8jseWpZ5D552/+KAa0Skhav3ZFpxmAS8BEyE/nI9Dwg9niYcZLWORWHAQPp +jraG0BkE7bn5No/k7E4rjFb+2N+36QxVacJI3neC8bQXVHP0BVUvrabOWFPnGivl +Ok91Eo8q5PUAsd15ZnKjTHzlD7zv7fF6ncBgj3P4L2Xrs6P34JOZEd4wixEUZYeC +Xe+SZrFyUr6CcNC45C6R3hDYqmrz0GK1ikkis3XcKT+C5flBYb9NRx8G9wyCuS6H +oHl0Rfbpc47wQTuajicMVO2El7syMPUAxjo3EfMzvjm7uCXLTHnXRnRt3Y5AkPGa +0kFE9Vm00PReRfQ7qbSUiOOHYa9NIsw1l2ZI+knP9XbY2HikELOpjgucrMxZF+ms +zit5YGD3NGZi5xcHZFZTs9L8kaJccXn5DtjA30eEiFzKqMtMKnwlrbSL55I1JXim +co1RLpRK2KQmtJHo1br3RH6jP7fePYzgDceDds5HKWz22pYFcVtlx4DeYH5vjdEp +i3yNQZ32jD2HYhgCK325QLP5S2UYmUOPWd4sEiwZMBPpPOlt0TqCdFKYgS2GHlSN +IYVBYelPUYsz9Kg0TFtLMZLNUmwsXJ+jqnLVtmFyoV6IIvbSCqQ9jxTbZQKxThK8 +A1G+nXBO41ZW8eQZUGx8CzbCj2JvtVThgErSRqAuYbvlUt7EI4Ac8veZC8rJIG0Q +ADkueb978o4OI6vpOdTYCmdTIoHWlpup +-----END ENCRYPTED PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/root_key_ec.pem b/s2a/src/test/resources/root_key_ec.pem new file mode 100644 index 00000000000..5560a66d414 --- /dev/null +++ b/s2a/src/test/resources/root_key_ec.pem @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgjfTyzPIlKV0zANQP +2s1C2FhbenE34QEsf83wjpuQrZWhRANCAAQadLaBW/oGzQa25QA55tH0TR+pqPfO +L2YP3sH92CIwsQ2G1U8lDN54+/oJxqn7vc7QEg1E06Xy2UdfRhpcN4Qc +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/server_cert.pem b/s2a/src/test/resources/server_cert.pem new file mode 100644 index 00000000000..909b83aa903 --- /dev/null +++ b/s2a/src/test/resources/server_cert.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDWjCCAkKgAwIBAgIUAeWzyzIEetYf+ZWHj9NzH1JkLYkwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTI0MTAwMTIxNTk0NloXDTQ0MTAwMTIxNTk0NlowFDESMBAGA1UEAwwJbG9jYWxo +b3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1qnW7Pb06MgRNLzt +icv/ydl8W/lpPRjrJJb04/TtXbJ1hjnp7i796TfNGrJgHqEZnaR8q83lO0L38B2X 
+sJ04b3R+y+6HhH8+MbHejM7ybrTZRNQXip/Kxu4QLHBTQEsplycWLf42/R3cIk/X +vgxq5NsCsbk4xI4xwlcqC8FM1AHU0VrKxzHWVhZEM+/KovBAr/hRYln9CukeKjOf +UiVq58uuDAlJRC3yH2Rd/sqCDELvqRv17J6eYx2nJ3mSN5aBa0FwVjg6vr5Obddj +AWWIkgrlAr+a+OraxOrWElFfChBSvr/qHdJFWHeCdq/SAhow5uRhC69ScJf+7lrX +hsj1sQIDAQABo18wXTAbBgNVHREEFDAShxAAAAAAAAAAAAAAAAAAAAABMB0GA1Ud +DgQWBBRdDRg6GuDj8Sujmz4/rqfP0jZHbTAfBgNVHSMEGDAWgBRyrd23EBUDjXRh +bIzQHh77i6ZuLDANBgkqhkiG9w0BAQsFAAOCAQEAAEUS27+6p88CWYemMOY0iu0e +mp4YqG0XQSilbSnxrqnJb3N8pR3Yh6JJKnblQ6xdexfzrXlBA/v7nx+f8e9HS2QZ +KLtEIaEvNKL51JdOS6ebEzLVvhk98r2kpKM3wpT++/18HPlPK5W3rMQNsLOyAdvP +UX6TakhIfflRjz1DYXQ1ERvJOFw2HEmw6K6r2VwBhZKfwwzxmAHpVwniWXGbgyRF +79hG6rO1tv1K5LHAPIRs0h2Lh/VPxm2XiaNkdGyarUy5/NM+GoHErgxOBmYltn5Q +vAlZrgF2/mSXcUb7EHoXvoC9L4M7U/dRQD4Q1fQRJ/KjrhbDAC3gfZ4zorKoaQ== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/server_key.pem b/s2a/src/test/resources/server_key.pem new file mode 100644 index 00000000000..edc37cb3855 --- /dev/null +++ b/s2a/src/test/resources/server_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDWqdbs9vToyBE0 +vO2Jy//J2Xxb+Wk9GOsklvTj9O1dsnWGOenuLv3pN80asmAeoRmdpHyrzeU7Qvfw +HZewnThvdH7L7oeEfz4xsd6MzvJutNlE1BeKn8rG7hAscFNASymXJxYt/jb9Hdwi +T9e+DGrk2wKxuTjEjjHCVyoLwUzUAdTRWsrHMdZWFkQz78qi8ECv+FFiWf0K6R4q +M59SJWrny64MCUlELfIfZF3+yoIMQu+pG/Xsnp5jHacneZI3loFrQXBWODq+vk5t +12MBZYiSCuUCv5r46trE6tYSUV8KEFK+v+od0kVYd4J2r9ICGjDm5GELr1Jwl/7u +WteGyPWxAgMBAAECggEAFEAgcOlZME6TZPS/ueSfRET6mNieB2/+2sxM3OZhsBmi +QZ/cBCa1uFcVx8N1Et6iwn7ebfy199G4/xNjmHs0dDs6rPVbHnI8hUag1oq9TxlL +d9VERUUOxZZ2uyJ7kBCnI0XCL2OQf29eMXRzx093lBBfIDH3e39ojUtYwZQiMcuw +EPry0k4fVhymhKg9Wnmt5lMg4Mdc1TpPfmNFuTR0PZ1nAaVQglvH66qNKGVoWEhZ +paNLaKC4H2Jfa1AfAWl6Efy5JDMOfHF0ww0cDUrTzAeQ7jEh0UGyL1lX8W6kKRDa +0quUqxOJz9aQ8cyd27s2OQMlRtbXi/jhhVp7WLIrWQKBgQD9gKG5CgBO/L8nIj5o +EhHFhtfjEhdeXTAlenmxoBxUN7Pwkc2OvhNef7+T0+euwl50ieopWLoRxLZ2yY8l 
+E2b2+7EM6/8/wgt1bCVh5NCWrE63tLCx+wdht1oqciDXvuv5bJTf73sipgDTYYSV +gE+DHXq96mxVJXo1TLtQQpXMVQKBgQDYx0AbO0KP2TTNY5ChqVwthaETHjWs6z9p +U5WRgNYeXbUKg3l7JJk6zq72ZIBeqEr3d9mJqrk6HFKTh4c+LyjKyLjmY5wkmfHh +s6s1lCEgEoXKT3Fa+DxlsXltyxrJLzuf1h276jeL5bB6BmJNKLODcEoCx/ubrwOj +prdUSWqf7QKBgQCO/sg7AJE7/QY2pPJe8hJkQbP1unbEG/zUp0mOEKrqNqGhyh0R +r9ZtL9J5KMc/pRRy2Hjl6c7LxxLF3tyIJXGnUEKG73iEFokwK1jK569hzsB4j8w8 +GUYIsMyDtO0hxeiGQeGYkBX9bXZ5xkBrtH0lkLNz/ZAuV32gIzBmDalCIQKBgDGT +f+m6Z8KWHilKt+0A2n/eq7O/mO7u7hWcc/xOxqkzLRA2eTXcbN6yHfljiqgbPOnT +kwCU9r9/crMir59dEasutHqcFT2Zp2PCv0kFk33OPqLCAF6ZntZy/B5L8NhJ4Qzw +3uP28LUh1nZRt3GF+Wf56jMwoS49nEt0+UBhee0RAoGAS9YsJkbjBg2p3Gxvo5c0 +IjfZdcyS2ndTjXv+hFvkjMw0ULFT3dqpk+0asaCh5nrDUbVQyan+D8LgwSwNZy89 +e99bl//oliv/Om7lVFCKtBOhe+fIWHlrR0e2bemsQi/pgTURjYFuvjhR50dcKx96 +jLHvG4mTfStHaJ1gKGWvgWA= +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/services/BUILD.bazel b/services/BUILD.bazel index d20e956ed49..ba9d334a5c9 100644 --- a/services/BUILD.bazel +++ b/services/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") load("//:java_grpc_library.bzl", "java_grpc_library") @@ -121,6 +122,7 @@ java_library( "@io_grpc_grpc_proto//:reflection_java_proto", "@io_grpc_grpc_proto//:reflection_java_proto_deprecated", artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), ], ) @@ -139,6 +141,7 @@ java_library( "//stub", "@io_grpc_grpc_proto//:health_java_proto", artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), ], ) diff --git a/services/build.gradle b/services/build.gradle index fade7aef3fb..c30e1ba53bd 100644 --- a/services/build.gradle +++ b/services/build.gradle @@ -32,14 +32,16 @@ dependencies { runtimeOnly libraries.errorprone.annotations, libraries.gson // to fix checkUpperBoundDeps error 
here - compileOnly libraries.javax.annotation testImplementation project(':grpc-testing'), project(':grpc-inprocess'), libraries.netty.transport.epoll, // for DomainSocketAddress testFixtures(project(':grpc-core')), testFixtures(project(':grpc-api')) - testCompileOnly libraries.javax.annotation - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } configureProtoCompilation() @@ -58,6 +60,7 @@ tasks.named("jacocoTestReport").configure { '**/io/grpc/binarylog/v1/**', '**/io/grpc/channelz/v1/**', '**/io/grpc/health/v1/**', + '**/io/grpc/reflection/v1/**', '**/io/grpc/reflection/v1alpha/**', ]) } diff --git a/services/src/generated/main/grpc/io/grpc/channelz/v1/ChannelzGrpc.java b/services/src/generated/main/grpc/io/grpc/channelz/v1/ChannelzGrpc.java index b3c1c285c8f..c4ac4076d22 100644 --- a/services/src/generated/main/grpc/io/grpc/channelz/v1/ChannelzGrpc.java +++ b/services/src/generated/main/grpc/io/grpc/channelz/v1/ChannelzGrpc.java @@ -8,9 +8,6 @@ * information. 
* */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/channelz/v1/channelz.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ChannelzGrpc { @@ -250,6 +247,21 @@ public ChannelzStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOpt return ChannelzStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ChannelzBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ChannelzBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ChannelzBlockingV2Stub(channel, callOptions); + } + }; + return ChannelzBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -481,6 +493,98 @@ public void getSocket(io.grpc.channelz.v1.GetSocketRequest request, * information. * */ + public static final class ChannelzBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ChannelzBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ChannelzBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ChannelzBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Gets all root channels (i.e. channels the application has directly
+     * created). This does not include subchannels nor non-top level channels.
+     * 
+ */ + public io.grpc.channelz.v1.GetTopChannelsResponse getTopChannels(io.grpc.channelz.v1.GetTopChannelsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetTopChannelsMethod(), getCallOptions(), request); + } + + /** + *
+     * Gets all servers that exist in the process.
+     * 
+ */ + public io.grpc.channelz.v1.GetServersResponse getServers(io.grpc.channelz.v1.GetServersRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetServersMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns a single Server, or else a NOT_FOUND code.
+     * 
+ */ + public io.grpc.channelz.v1.GetServerResponse getServer(io.grpc.channelz.v1.GetServerRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetServerMethod(), getCallOptions(), request); + } + + /** + *
+     * Gets all server sockets that exist in the process.
+     * 
+ */ + public io.grpc.channelz.v1.GetServerSocketsResponse getServerSockets(io.grpc.channelz.v1.GetServerSocketsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetServerSocketsMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns a single Channel, or else a NOT_FOUND code.
+     * 
+ */ + public io.grpc.channelz.v1.GetChannelResponse getChannel(io.grpc.channelz.v1.GetChannelRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetChannelMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns a single Subchannel, or else a NOT_FOUND code.
+     * 
+ */ + public io.grpc.channelz.v1.GetSubchannelResponse getSubchannel(io.grpc.channelz.v1.GetSubchannelRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetSubchannelMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns a single Socket or else a NOT_FOUND code.
+     * 
+ */ + public io.grpc.channelz.v1.GetSocketResponse getSocket(io.grpc.channelz.v1.GetSocketRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetSocketMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service Channelz. + *
+   * Channelz is a service exposed by gRPC servers that provides detailed debug
+   * information.
+   * 
+ */ public static final class ChannelzBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ChannelzBlockingStub( diff --git a/services/src/generated/main/grpc/io/grpc/health/v1/HealthGrpc.java b/services/src/generated/main/grpc/io/grpc/health/v1/HealthGrpc.java index 73ddd4e0d23..b8e94ef7d20 100644 --- a/services/src/generated/main/grpc/io/grpc/health/v1/HealthGrpc.java +++ b/services/src/generated/main/grpc/io/grpc/health/v1/HealthGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/health/v1/health.proto") @io.grpc.stub.annotations.GrpcGenerated public final class HealthGrpc { @@ -91,6 +88,21 @@ public HealthStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptio return HealthStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static HealthBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public HealthBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new HealthBlockingV2Stub(channel, callOptions); + } + }; + return HealthBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -225,6 +237,58 @@ public void watch(io.grpc.health.v1.HealthCheckRequest request, /** * A stub to allow clients to do synchronous rpc calls to service Health. 
*/ + public static final class HealthBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private HealthBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected HealthBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new HealthBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * If the requested service is unknown, the call will fail with status
+     * NOT_FOUND.
+     * 
+ */ + public io.grpc.health.v1.HealthCheckResponse check(io.grpc.health.v1.HealthCheckRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getCheckMethod(), getCallOptions(), request); + } + + /** + *
+     * Performs a watch for the serving status of the requested service.
+     * The server will immediately send back a message indicating the current
+     * serving status.  It will then subsequently send a new message whenever
+     * the service's serving status changes.
+     * If the requested service is unknown when the call is received, the
+     * server will send a message setting the serving status to
+     * SERVICE_UNKNOWN but will *not* terminate the call.  If at some
+     * future point, the serving status of the service becomes known, the
+     * server will send a new message with the service's serving status.
+     * If the call terminates with status UNIMPLEMENTED, then clients
+     * should assume this method is not supported and should not retry the
+     * call.  If the call terminates with any other status (including OK),
+     * clients should retry the call with appropriate exponential backoff.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + watch(io.grpc.health.v1.HealthCheckRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getWatchMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service Health. + */ public static final class HealthBlockingStub extends io.grpc.stub.AbstractBlockingStub { private HealthBlockingStub( diff --git a/services/src/generated/main/grpc/io/grpc/reflection/v1/ServerReflectionGrpc.java b/services/src/generated/main/grpc/io/grpc/reflection/v1/ServerReflectionGrpc.java index 4f2dce26486..04f8dea3ace 100644 --- a/services/src/generated/main/grpc/io/grpc/reflection/v1/ServerReflectionGrpc.java +++ b/services/src/generated/main/grpc/io/grpc/reflection/v1/ServerReflectionGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/reflection/v1/reflection.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ServerReflectionGrpc { @@ -60,6 +57,21 @@ public ServerReflectionStub newStub(io.grpc.Channel channel, io.grpc.CallOptions return ServerReflectionStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ServerReflectionBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ServerReflectionBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionBlockingV2Stub(channel, callOptions); + } + }; + return ServerReflectionBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -149,6 +161,36 @@ 
public io.grpc.stub.StreamObserver { + private ServerReflectionBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ServerReflectionBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * The reflection service is structured as a bidirectional stream, ensuring
+     * all related requests go to a single server.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + serverReflectionInfo() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getServerReflectionInfoMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ServerReflection. + */ public static final class ServerReflectionBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ServerReflectionBlockingStub( diff --git a/services/src/generated/main/grpc/io/grpc/reflection/v1alpha/ServerReflectionGrpc.java b/services/src/generated/main/grpc/io/grpc/reflection/v1alpha/ServerReflectionGrpc.java index 7119e96d1f3..3cbb3a1d1b9 100644 --- a/services/src/generated/main/grpc/io/grpc/reflection/v1alpha/ServerReflectionGrpc.java +++ b/services/src/generated/main/grpc/io/grpc/reflection/v1alpha/ServerReflectionGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/reflection/v1alpha/reflection.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ServerReflectionGrpc { @@ -60,6 +57,21 @@ public ServerReflectionStub newStub(io.grpc.Channel channel, io.grpc.CallOptions return ServerReflectionStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ServerReflectionBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ServerReflectionBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionBlockingV2Stub(channel, callOptions); + } + }; + return ServerReflectionBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the 
service */ @@ -149,6 +161,36 @@ public io.grpc.stub.StreamObserver { + private ServerReflectionBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ServerReflectionBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * The reflection service is structured as a bidirectional stream, ensuring
+     * all related requests go to a single server.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + serverReflectionInfo() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getServerReflectionInfoMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ServerReflection. + */ public static final class ServerReflectionBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ServerReflectionBlockingStub( diff --git a/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherDynamicServiceGrpc.java b/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherDynamicServiceGrpc.java index 088d27b619c..978af2d887e 100644 --- a/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherDynamicServiceGrpc.java +++ b/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherDynamicServiceGrpc.java @@ -7,9 +7,6 @@ * AnotherDynamicService * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: io/grpc/reflection/testing/dynamic_reflection_test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class AnotherDynamicServiceGrpc { @@ -63,6 +60,21 @@ public AnotherDynamicServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOp return AnotherDynamicServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static AnotherDynamicServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AnotherDynamicServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AnotherDynamicServiceBlockingV2Stub(channel, callOptions); + } + }; + return AnotherDynamicServiceBlockingV2Stub.newStub(factory, channel); + } + /** * 
Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -162,6 +174,36 @@ public void method(io.grpc.reflection.testing.DynamicRequest request, * AnotherDynamicService * */ + public static final class AnotherDynamicServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private AnotherDynamicServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AnotherDynamicServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AnotherDynamicServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * A method
+     * 
+ */ + public io.grpc.reflection.testing.DynamicReply method(io.grpc.reflection.testing.DynamicRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getMethodMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service AnotherDynamicService. + *
+   * AnotherDynamicService
+   * 
+ */ public static final class AnotherDynamicServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private AnotherDynamicServiceBlockingStub( diff --git a/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherReflectableServiceGrpc.java b/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherReflectableServiceGrpc.java index a84b95b2126..e688c3d5cca 100644 --- a/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherReflectableServiceGrpc.java +++ b/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherReflectableServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: io/grpc/reflection/testing/reflection_test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class AnotherReflectableServiceGrpc { @@ -60,6 +57,21 @@ public AnotherReflectableServiceStub newStub(io.grpc.Channel channel, io.grpc.Ca return AnotherReflectableServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static AnotherReflectableServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AnotherReflectableServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AnotherReflectableServiceBlockingV2Stub(channel, callOptions); + } + }; + return AnotherReflectableServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -141,6 +153,30 @@ public void method(io.grpc.reflection.testing.Request request, /** * A stub to allow clients to do synchronous rpc calls to service AnotherReflectableService. 
*/ + public static final class AnotherReflectableServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private AnotherReflectableServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AnotherReflectableServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AnotherReflectableServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.reflection.testing.Reply method(io.grpc.reflection.testing.Request request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getMethodMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service AnotherReflectableService. + */ public static final class AnotherReflectableServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private AnotherReflectableServiceBlockingStub( diff --git a/services/src/generated/test/grpc/io/grpc/reflection/testing/DynamicServiceGrpc.java b/services/src/generated/test/grpc/io/grpc/reflection/testing/DynamicServiceGrpc.java index 338b67e684d..efef61be151 100644 --- a/services/src/generated/test/grpc/io/grpc/reflection/testing/DynamicServiceGrpc.java +++ b/services/src/generated/test/grpc/io/grpc/reflection/testing/DynamicServiceGrpc.java @@ -7,9 +7,6 @@ * A DynamicService * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: io/grpc/reflection/testing/dynamic_reflection_test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class DynamicServiceGrpc { @@ -63,6 +60,21 @@ public DynamicServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions c return DynamicServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static DynamicServiceBlockingV2Stub 
newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public DynamicServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new DynamicServiceBlockingV2Stub(channel, callOptions); + } + }; + return DynamicServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -162,6 +174,36 @@ public void method(io.grpc.reflection.testing.DynamicRequest request, * A DynamicService * */ + public static final class DynamicServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private DynamicServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected DynamicServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new DynamicServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * A method
+     * 
+ */ + public io.grpc.reflection.testing.DynamicReply method(io.grpc.reflection.testing.DynamicRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getMethodMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service DynamicService. + *
+   * A DynamicService
+   * 
+ */ public static final class DynamicServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private DynamicServiceBlockingStub( diff --git a/services/src/generated/test/grpc/io/grpc/reflection/testing/ReflectableServiceGrpc.java b/services/src/generated/test/grpc/io/grpc/reflection/testing/ReflectableServiceGrpc.java index 0b8954b5eb9..b5d130d6952 100644 --- a/services/src/generated/test/grpc/io/grpc/reflection/testing/ReflectableServiceGrpc.java +++ b/services/src/generated/test/grpc/io/grpc/reflection/testing/ReflectableServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: io/grpc/reflection/testing/reflection_test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ReflectableServiceGrpc { @@ -60,6 +57,21 @@ public ReflectableServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptio return ReflectableServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ReflectableServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ReflectableServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReflectableServiceBlockingV2Stub(channel, callOptions); + } + }; + return ReflectableServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -141,6 +153,30 @@ public void method(io.grpc.reflection.testing.Request request, /** * A stub to allow clients to do synchronous rpc calls to service ReflectableService. 
*/ + public static final class ReflectableServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ReflectableServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ReflectableServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReflectableServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.reflection.testing.Reply method(io.grpc.reflection.testing.Request request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getMethodMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ReflectableService. + */ public static final class ReflectableServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ReflectableServiceBlockingStub( diff --git a/services/src/main/java/io/grpc/protobuf/services/ChannelzProtoUtil.java b/services/src/main/java/io/grpc/protobuf/services/ChannelzProtoUtil.java index cf003b2f881..74448a8c5bf 100644 --- a/services/src/main/java/io/grpc/protobuf/services/ChannelzProtoUtil.java +++ b/services/src/main/java/io/grpc/protobuf/services/ChannelzProtoUtil.java @@ -21,6 +21,7 @@ import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.Int64Value; +import com.google.protobuf.MessageLite; import com.google.protobuf.util.Durations; import com.google.protobuf.util.Timestamps; import io.grpc.ConnectivityState; @@ -79,6 +80,8 @@ /** * A static utility class for turning internal data structures into protos. + * + *

Works with both regular and lite protos. */ final class ChannelzProtoUtil { private static final Logger logger = Logger.getLogger(ChannelzProtoUtil.class.getName()); @@ -254,22 +257,20 @@ static SocketOption toSocketOptionLinger(int lingerSeconds) { } else { lingerOpt = SocketOptionLinger.getDefaultInstance(); } - return SocketOption - .newBuilder() + return SocketOption.newBuilder() .setName(SO_LINGER) - .setAdditional(Any.pack(lingerOpt)) + .setAdditional(packToAny("SocketOptionLinger", lingerOpt)) .build(); } static SocketOption toSocketOptionTimeout(String name, int timeoutMillis) { Preconditions.checkNotNull(name); - return SocketOption - .newBuilder() + return SocketOption.newBuilder() .setName(name) .setAdditional( - Any.pack( - SocketOptionTimeout - .newBuilder() + packToAny( + "SocketOptionTimeout", + SocketOptionTimeout.newBuilder() .setDuration(Durations.fromMillis(timeoutMillis)) .build())) .build(); @@ -307,10 +308,9 @@ static SocketOption toSocketOptionTcpInfo(InternalChannelz.TcpInfo i) { .setTcpiAdvmss(i.advmss) .setTcpiReordering(i.reordering) .build(); - return SocketOption - .newBuilder() + return SocketOption.newBuilder() .setName(TCP_INFO) - .setAdditional(Any.pack(tcpInfo)) + .setAdditional(packToAny("SocketOptionTcpInfo", tcpInfo)) .build(); } @@ -380,10 +380,11 @@ private static ChannelTrace toChannelTrace(InternalChannelz.ChannelTrace channel private static List toChannelTraceEvents(List events) { List channelTraceEvents = new ArrayList<>(); for (Event event : events) { - ChannelTraceEvent.Builder builder = ChannelTraceEvent.newBuilder() - .setDescription(event.description) - .setSeverity(Severity.valueOf(event.severity.name())) - .setTimestamp(Timestamps.fromNanos(event.timestampNanos)); + ChannelTraceEvent.Builder builder = + ChannelTraceEvent.newBuilder() + .setDescription(event.description) + .setSeverity(toSeverity(event.severity)) + .setTimestamp(Timestamps.fromNanos(event.timestampNanos)); if (event.channelRef != null) { 
builder.setChannelRef(toChannelRef(event.channelRef)); } @@ -395,14 +396,39 @@ private static List toChannelTraceEvents(List events) return Collections.unmodifiableList(channelTraceEvents); } + static Severity toSeverity(Event.Severity severity) { + if (severity == null) { + return Severity.CT_UNKNOWN; + } + switch (severity) { + case CT_INFO: + return Severity.CT_INFO; + case CT_ERROR: + return Severity.CT_ERROR; + case CT_WARNING: + return Severity.CT_WARNING; + default: + return Severity.CT_UNKNOWN; + } + } + static State toState(ConnectivityState state) { if (state == null) { return State.UNKNOWN; } - try { - return Enum.valueOf(State.class, state.name()); - } catch (IllegalArgumentException e) { - return State.UNKNOWN; + switch (state) { + case IDLE: + return State.IDLE; + case READY: + return State.READY; + case CONNECTING: + return State.CONNECTING; + case SHUTDOWN: + return State.SHUTDOWN; + case TRANSIENT_FAILURE: + return State.TRANSIENT_FAILURE; + default: + return State.UNKNOWN; } } @@ -468,4 +494,12 @@ private static T getFuture(ListenableFuture future) { throw Status.INTERNAL.withCause(e).asRuntimeException(); } } + + // A version of Any.pack() that works with protolite. + private static Any packToAny(String typeName, MessageLite value) { + return Any.newBuilder() + .setTypeUrl("type.googleapis.com/grpc.channelz.v1." 
+ typeName) + .setValue(value.toByteString()) + .build(); + } } diff --git a/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java b/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java index cac522caf9e..b9f235d0aff 100644 --- a/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java +++ b/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java @@ -144,6 +144,30 @@ void setHealthCheckedService(@Nullable String service) { public String toString() { return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); } + + @Override + public void updateBalancingState( + io.grpc.ConnectivityState newState, LoadBalancer.SubchannelPicker newPicker) { + delegate().updateBalancingState(newState, new HealthCheckPicker(newPicker)); + } + + private final class HealthCheckPicker extends LoadBalancer.SubchannelPicker { + private final LoadBalancer.SubchannelPicker delegate; + + HealthCheckPicker(LoadBalancer.SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) { + LoadBalancer.PickResult result = delegate.pickSubchannel(args); + LoadBalancer.Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof SubchannelImpl) { + return result.copyWithSubchannel(((SubchannelImpl) subchannel).delegate()); + } + return result; + } + } } @VisibleForTesting @@ -194,7 +218,18 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { .get(LoadBalancer.ATTR_HEALTH_CHECKING_CONFIG); String serviceName = ServiceConfigUtil.getHealthCheckedServiceName(healthCheckingConfig); helper.setHealthCheckedService(serviceName); - super.handleResolvedAddresses(resolvedAddresses); + delegate.handleResolvedAddresses(resolvedAddresses); + } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses 
resolvedAddresses) { + Map healthCheckingConfig = + resolvedAddresses + .getAttributes() + .get(LoadBalancer.ATTR_HEALTH_CHECKING_CONFIG); + String serviceName = ServiceConfigUtil.getHealthCheckedServiceName(healthCheckingConfig); + helper.setHealthCheckedService(serviceName); + return delegate.acceptResolvedAddresses(resolvedAddresses); } @Override diff --git a/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java b/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java index 6ce602b9295..5cd294b4fbe 100644 --- a/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java +++ b/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java @@ -18,6 +18,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Context; import io.grpc.Context.CancellationListener; import io.grpc.Status; @@ -26,6 +27,7 @@ import io.grpc.health.v1.HealthCheckResponse; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; import io.grpc.health.v1.HealthGrpc; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import java.util.HashMap; import java.util.IdentityHashMap; @@ -34,7 +36,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; final class HealthServiceImpl extends HealthGrpc.HealthImplBase { @@ -83,6 +84,11 @@ public void watch(HealthCheckRequest request, final StreamObserver responseObserver) { final String service = request.getService(); synchronized (watchLock) { + if (responseObserver instanceof ServerCallStreamObserver) { + ((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> { + removeWatcher(service, responseObserver); + }); + } ServingStatus status = statusMap.get(service); responseObserver.onNext(getResponseForWatch(status)); 
IdentityHashMap, Boolean> serviceWatchers = @@ -98,21 +104,25 @@ public void watch(HealthCheckRequest request, @Override // Called when the client has closed the stream public void cancelled(Context context) { - synchronized (watchLock) { - IdentityHashMap, Boolean> serviceWatchers = - watchers.get(service); - if (serviceWatchers != null) { - serviceWatchers.remove(responseObserver); - if (serviceWatchers.isEmpty()) { - watchers.remove(service); - } - } - } + removeWatcher(service, responseObserver); } }, MoreExecutors.directExecutor()); } + void removeWatcher(String service, StreamObserver responseObserver) { + synchronized (watchLock) { + IdentityHashMap, Boolean> serviceWatchers = + watchers.get(service); + if (serviceWatchers != null) { + serviceWatchers.remove(responseObserver); + if (serviceWatchers.isEmpty()) { + watchers.remove(service); + } + } + } + } + void setStatus(String service, ServingStatus status) { synchronized (watchLock) { if (terminal) { diff --git a/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java b/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java index 45947ed44ee..07008b682c3 100644 --- a/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java +++ b/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java @@ -28,11 +28,11 @@ /** * Provides a reflection service for Protobuf services (including the reflection service itself). - * Uses the deprecated v1alpha proto. New users should use ProtoReflectionServiceV1 instead. * *

Separately tracks mutable and immutable services. Throws an exception if either group of * services contains multiple Protobuf files with declarations of the same service, method, type, or * extension. + * Uses the deprecated v1alpha proto. New users should use {@link ProtoReflectionServiceV1} instead. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2222") public final class ProtoReflectionService implements BindableService { @@ -40,11 +40,13 @@ public final class ProtoReflectionService implements BindableService { private ProtoReflectionService() { } + @Deprecated public static BindableService newInstance() { return new ProtoReflectionService(); } @Override + @SuppressWarnings("deprecation") public ServerServiceDefinition bindService() { ServerServiceDefinition serverServiceDefinitionV1 = ProtoReflectionServiceV1.newInstance() .bindService(); diff --git a/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionServiceV1.java b/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionServiceV1.java index 578e9bbd409..59e9c33d279 100644 --- a/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionServiceV1.java +++ b/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionServiceV1.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.FileDescriptor; @@ -52,7 +53,6 @@ import java.util.Set; import java.util.WeakHashMap; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Provides a reflection service for Protobuf services (including the reflection service itself). 
diff --git a/services/src/main/proto/grpc/binlog/v1/binarylog.proto b/services/src/main/proto/grpc/binlog/v1/binarylog.proto index 9ed1733e2d8..b18bd88ddc9 100644 --- a/services/src/main/proto/grpc/binlog/v1/binarylog.proto +++ b/services/src/main/proto/grpc/binlog/v1/binarylog.proto @@ -120,7 +120,7 @@ message ClientHeader { // A single process may be used to run multiple virtual // servers with different identities. - // The authority is the name of such a server identitiy. + // The authority is the name of such a server identity. // It is typically a portion of the URI in the form of // or : . string authority = 3; diff --git a/services/src/main/proto/grpc/reflection/v1alpha/reflection.proto b/services/src/main/proto/grpc/reflection/v1alpha/reflection.proto index 8c5e06fe148..a3984b55c2d 100644 --- a/services/src/main/proto/grpc/reflection/v1alpha/reflection.proto +++ b/services/src/main/proto/grpc/reflection/v1alpha/reflection.proto @@ -80,7 +80,7 @@ message ExtensionRequest { message ServerReflectionResponse { string valid_host = 1; ServerReflectionRequest original_request = 2; - // The server set one of the following fields accroding to the message_request + // The server set one of the following fields according to the message_request // in the request. oneof message_response { // This message is used to answer file_by_filename, file_containing_symbol, @@ -91,7 +91,7 @@ message ServerReflectionResponse { // that were previously sent in response to earlier requests in the stream. FileDescriptorResponse file_descriptor_response = 4; - // This message is used to answer all_extension_numbers_of_type requst. + // This message is used to answer all_extension_numbers_of_type request. ExtensionNumberResponse all_extension_numbers_response = 5; // This message is used to answer list_services request. 
diff --git a/services/src/test/java/io/grpc/protobuf/services/ChannelzProtoUtilTest.java b/services/src/test/java/io/grpc/protobuf/services/ChannelzProtoUtilTest.java index 0d2e6063d5e..598a8625e58 100644 --- a/services/src/test/java/io/grpc/protobuf/services/ChannelzProtoUtilTest.java +++ b/services/src/test/java/io/grpc/protobuf/services/ChannelzProtoUtilTest.java @@ -27,7 +27,7 @@ import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.Int64Value; -import com.google.protobuf.Message; +import com.google.protobuf.MessageLite; import com.google.protobuf.util.Durations; import com.google.protobuf.util.Timestamps; import io.grpc.ConnectivityState; @@ -154,33 +154,44 @@ public final class ChannelzProtoUtilTest { .setData(serverData) .build(); - private final SocketOption sockOptLingerDisabled = SocketOption - .newBuilder() - .setName("SO_LINGER") - .setAdditional( - Any.pack(SocketOptionLinger.getDefaultInstance())) - .build(); - - private final SocketOption sockOptlinger10s = SocketOption - .newBuilder() - .setName("SO_LINGER") - .setAdditional( - Any.pack(SocketOptionLinger - .newBuilder() - .setActive(true) - .setDuration(Durations.fromSeconds(10)) - .build())) - .build(); - - private final SocketOption sockOptTimeout200ms = SocketOption - .newBuilder() - .setName("SO_TIMEOUT") - .setAdditional( - Any.pack(SocketOptionTimeout - .newBuilder() - .setDuration(Durations.fromMillis(200)) - .build()) - ).build(); + private final SocketOption sockOptLingerDisabled = + SocketOption.newBuilder() + .setName("SO_LINGER") + .setAdditional( + Any.newBuilder() + .setTypeUrl("type.googleapis.com/grpc.channelz.v1.SocketOptionLinger") + .setValue(SocketOptionLinger.getDefaultInstance().toByteString()) + .build()) + .build(); + + private final SocketOption sockOptlinger10s = + SocketOption.newBuilder() + .setName("SO_LINGER") + .setAdditional( + Any.newBuilder() + .setTypeUrl("type.googleapis.com/grpc.channelz.v1.SocketOptionLinger") + 
.setValue( + SocketOptionLinger.newBuilder() + .setActive(true) + .setDuration(Durations.fromSeconds(10)) + .build() + .toByteString()) + .build()) + .build(); + + private final SocketOption sockOptTimeout200ms = + SocketOption.newBuilder() + .setName("SO_TIMEOUT") + .setAdditional( + Any.newBuilder() + .setTypeUrl("type.googleapis.com/grpc.channelz.v1.SocketOptionTimeout") + .setValue( + SocketOptionTimeout.newBuilder() + .setDuration(Durations.fromMillis(200)) + .build() + .toByteString()) + .build()) + .build(); private final SocketOption sockOptAdditional = SocketOption .newBuilder() @@ -221,43 +232,46 @@ public final class ChannelzProtoUtilTest { .setReordering(728) .build(); - private final SocketOption socketOptionTcpInfo = SocketOption - .newBuilder() - .setName("TCP_INFO") - .setAdditional( - Any.pack( - SocketOptionTcpInfo.newBuilder() - .setTcpiState(70) - .setTcpiCaState(71) - .setTcpiRetransmits(72) - .setTcpiProbes(73) - .setTcpiBackoff(74) - .setTcpiOptions(75) - .setTcpiSndWscale(76) - .setTcpiRcvWscale(77) - .setTcpiRto(78) - .setTcpiAto(79) - .setTcpiSndMss(710) - .setTcpiRcvMss(711) - .setTcpiUnacked(712) - .setTcpiSacked(713) - .setTcpiLost(714) - .setTcpiRetrans(715) - .setTcpiFackets(716) - .setTcpiLastDataSent(717) - .setTcpiLastAckSent(718) - .setTcpiLastDataRecv(719) - .setTcpiLastAckRecv(720) - .setTcpiPmtu(721) - .setTcpiRcvSsthresh(722) - .setTcpiRtt(723) - .setTcpiRttvar(724) - .setTcpiSndSsthresh(725) - .setTcpiSndCwnd(726) - .setTcpiAdvmss(727) - .setTcpiReordering(728) - .build())) - .build(); + private final SocketOption socketOptionTcpInfo = + SocketOption.newBuilder() + .setName("TCP_INFO") + .setAdditional( + Any.newBuilder() + .setTypeUrl("type.googleapis.com/grpc.channelz.v1.SocketOptionTcpInfo") + .setValue( + SocketOptionTcpInfo.newBuilder() + .setTcpiState(70) + .setTcpiCaState(71) + .setTcpiRetransmits(72) + .setTcpiProbes(73) + .setTcpiBackoff(74) + .setTcpiOptions(75) + .setTcpiSndWscale(76) + .setTcpiRcvWscale(77) + 
.setTcpiRto(78) + .setTcpiAto(79) + .setTcpiSndMss(710) + .setTcpiRcvMss(711) + .setTcpiUnacked(712) + .setTcpiSacked(713) + .setTcpiLost(714) + .setTcpiRetrans(715) + .setTcpiFackets(716) + .setTcpiLastDataSent(717) + .setTcpiLastAckSent(718) + .setTcpiLastDataRecv(719) + .setTcpiLastAckRecv(720) + .setTcpiPmtu(721) + .setTcpiRcvSsthresh(722) + .setTcpiRtt(723) + .setTcpiRttvar(724) + .setTcpiSndSsthresh(725) + .setTcpiSndCwnd(726) + .setTcpiAdvmss(727) + .setTcpiReordering(728) + .build() + .toByteString())) + .build(); private final TestListenSocket listenSocket = new TestListenSocket(); private final SocketRef listenSocketRef = SocketRef @@ -336,6 +350,16 @@ public void toServerRef() { assertEquals(serverRef, ChannelzProtoUtil.toServerRef(server)); } + @Test + public void toSeverity() { + for (Severity severity : Severity.values()) { + assertEquals( + severity.name(), + ChannelzProtoUtil.toSeverity(severity).name()); // OK because test isn't proguarded. + } + assertEquals(ChannelTraceEvent.Severity.CT_UNKNOWN, ChannelzProtoUtil.toSeverity(null)); + } + @Test public void toSocketRef() { assertEquals(socketRef, ChannelzProtoUtil.toSocketRef(socket)); @@ -346,7 +370,7 @@ public void toState() { for (ConnectivityState connectivityState : ConnectivityState.values()) { assertEquals( connectivityState.name(), - ChannelzProtoUtil.toState(connectivityState).getValueDescriptor().getName()); + ChannelzProtoUtil.toState(connectivityState).name()); // OK because test isn't proguarded. 
} assertEquals(State.UNKNOWN, ChannelzProtoUtil.toState(null)); } @@ -475,8 +499,12 @@ public void socketSecurityTls() throws Exception { @Test public void socketSecurityOther() throws Exception { // what is packed here is not important, just pick some proto message - Message contents = GetChannelRequest.newBuilder().setChannelId(1).build(); - Any packed = Any.pack(contents); + MessageLite contents = GetChannelRequest.newBuilder().setChannelId(1).build(); + Any packed = + Any.newBuilder() + .setTypeUrl("type.googleapis.com/grpc.channelz.v1.GetChannelRequest") + .setValue(contents.toByteString()) + .build(); socket.security = new InternalChannelz.Security( new InternalChannelz.OtherSecurity("other_security", packed)); diff --git a/services/src/test/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactoryTest.java b/services/src/test/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactoryTest.java index 08a33106fb9..a49c426f7e1 100644 --- a/services/src/test/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactoryTest.java +++ b/services/src/test/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactoryTest.java @@ -206,15 +206,16 @@ public void setup() throws Exception { boolean shutdown; @Override - public void handleResolvedAddresses(final ResolvedAddresses resolvedAddresses) { + public Status acceptResolvedAddresses(final ResolvedAddresses resolvedAddresses) { syncContext.execute(new Runnable() { @Override public void run() { if (!shutdown) { - hcLb.handleResolvedAddresses(resolvedAddresses); + hcLb.acceptResolvedAddresses(resolvedAddresses); } } }); + return Status.OK; } @Override @@ -264,9 +265,9 @@ public void typicalWorkflow() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); verify(origHelper, 
atLeast(0)).getSynchronizationContext(); verify(origHelper, atLeast(0)).getScheduledExecutorService(); verifyNoMoreInteractions(origHelper); @@ -404,9 +405,9 @@ public void healthCheckDisabledWhenServiceNotImplemented() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); verifyNoMoreInteractions(origLb); // We create 2 Subchannels. One of them connects to a server that doesn't implement health check @@ -489,9 +490,9 @@ public void backoffRetriesWhenServerErroneouslyClosesRpcBeforeAnyResponse() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); verifyNoMoreInteractions(origLb); SubchannelStateListener mockHealthListener = mockHealthListeners[0]; @@ -567,9 +568,9 @@ public void serverRespondResetsBackoff() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); verifyNoMoreInteractions(origLb); SubchannelStateListener mockStateListener = mockStateListeners[0]; @@ -667,9 +668,9 @@ public void serviceConfigHasNoHealthCheckingInitiallyButDoesLater() { .setAddresses(resolvedAddressList) .setAttributes(Attributes.EMPTY) .build(); - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); - verify(origLb).handleResolvedAddresses(result1); + verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb); // First, create 
Subchannels 0 @@ -688,8 +689,8 @@ public void serviceConfigHasNoHealthCheckingInitiallyButDoesLater() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); - verify(origLb).handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); + verify(origLb).acceptResolvedAddresses(result2); // Health check started on existing Subchannel assertThat(healthImpls[0].calls).hasSize(1); @@ -711,9 +712,9 @@ public void serviceConfigDisablesHealthCheckWhenRpcActive() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); - verify(origLb).handleResolvedAddresses(result1); + verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb); Subchannel subchannel = createSubchannel(0, Attributes.EMPTY, maybeGetMockListener()); @@ -738,7 +739,7 @@ public void serviceConfigDisablesHealthCheckWhenRpcActive() { .setAddresses(resolvedAddressList) .setAttributes(Attributes.EMPTY) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); // Health check RPC cancelled. 
assertThat(serverCall.cancelled).isTrue(); @@ -746,7 +747,7 @@ public void serviceConfigDisablesHealthCheckWhenRpcActive() { inOrder.verify(getMockListener()).onSubchannelState( eq(ConnectivityStateInfo.forNonError(READY))); - inOrder.verify(origLb).handleResolvedAddresses(result2); + inOrder.verify(origLb).acceptResolvedAddresses(result2); verifyNoMoreInteractions(origLb, mockStateListeners[0]); assertThat(healthImpl.calls).isEmpty(); @@ -759,9 +760,9 @@ public void serviceConfigDisablesHealthCheckWhenRetryPending() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); verifyNoMoreInteractions(origLb); SubchannelStateListener mockHealthListener = mockHealthListeners[0]; @@ -793,7 +794,7 @@ public void serviceConfigDisablesHealthCheckWhenRetryPending() { .setAddresses(resolvedAddressList) .setAttributes(Attributes.EMPTY) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); // Retry timer is cancelled assertThat(clock.getPendingTasks()).isEmpty(); @@ -805,7 +806,7 @@ public void serviceConfigDisablesHealthCheckWhenRetryPending() { inOrder.verify(getMockListener()).onSubchannelState( eq(ConnectivityStateInfo.forNonError(READY))); - inOrder.verify(origLb).handleResolvedAddresses(result2); + inOrder.verify(origLb).acceptResolvedAddresses(result2); verifyNoMoreInteractions(origLb, mockStateListeners[0]); } @@ -817,9 +818,9 @@ public void serviceConfigDisablesHealthCheckWhenRpcInactive() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); - verify(origLb).handleResolvedAddresses(result1); + verify(origLb).acceptResolvedAddresses(result1); 
verifyNoMoreInteractions(origLb); Subchannel subchannel = createSubchannel(0, Attributes.EMPTY, maybeGetMockListener()); @@ -842,9 +843,9 @@ public void serviceConfigDisablesHealthCheckWhenRpcInactive() { .setAddresses(resolvedAddressList) .setAttributes(Attributes.EMPTY) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); - inOrder.verify(origLb).handleResolvedAddresses(result2); + inOrder.verify(origLb).acceptResolvedAddresses(result2); // Underlying subchannel is now ready deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); @@ -870,9 +871,9 @@ public void serviceConfigChangesServiceNameWhenRpcActive() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); - verify(origLb).handleResolvedAddresses(result1); + verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb); SubchannelStateListener mockHealthListener = mockHealthListeners[0]; @@ -900,9 +901,9 @@ public void serviceConfigChangesServiceNameWhenRpcActive() { eq(ConnectivityStateInfo.forNonError(READY))); // Service config returns with the same health check name. - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); // It's delivered to origLb, but nothing else happens - inOrder.verify(origLb).handleResolvedAddresses(result1); + inOrder.verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb, mockListener); // Service config returns a different health check name. 
@@ -911,8 +912,8 @@ public void serviceConfigChangesServiceNameWhenRpcActive() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); - inOrder.verify(origLb).handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); + inOrder.verify(origLb).acceptResolvedAddresses(result2); // Current health check RPC cancelled. assertThat(serverCall.cancelled).isTrue(); @@ -934,9 +935,9 @@ public void serviceConfigChangesServiceNameWhenRetryPending() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); - verify(origLb).handleResolvedAddresses(result1); + verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb); SubchannelStateListener mockHealthListener = mockHealthListeners[0]; @@ -969,9 +970,9 @@ public void serviceConfigChangesServiceNameWhenRetryPending() { // Service config returns with the same health check name. 
- hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); // It's delivered to origLb, but nothing else happens - inOrder.verify(origLb).handleResolvedAddresses(result1); + inOrder.verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb, mockListener); assertThat(clock.getPendingTasks()).hasSize(1); assertThat(healthImpl.calls).isEmpty(); @@ -982,12 +983,12 @@ public void serviceConfigChangesServiceNameWhenRetryPending() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); // Concluded CONNECTING state inOrder.verify(getMockListener()).onSubchannelState( eq(ConnectivityStateInfo.forNonError(CONNECTING))); - inOrder.verify(origLb).handleResolvedAddresses(result2); + inOrder.verify(origLb).acceptResolvedAddresses(result2); // Current retry timer cancelled assertThat(clock.getPendingTasks()).isEmpty(); @@ -1008,9 +1009,9 @@ public void serviceConfigChangesServiceNameWhenRpcInactive() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); - verify(origLb).handleResolvedAddresses(result1); + verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb); Subchannel subchannel = createSubchannel(0, Attributes.EMPTY, maybeGetMockListener()); @@ -1031,9 +1032,9 @@ public void serviceConfigChangesServiceNameWhenRpcInactive() { inOrder.verifyNoMoreInteractions(); // Service config returns with the same health check name. 
- hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); // It's delivered to origLb, but nothing else happens - inOrder.verify(origLb).handleResolvedAddresses(result1); + inOrder.verify(origLb).acceptResolvedAddresses(result1); assertThat(healthImpl.calls).isEmpty(); verifyNoMoreInteractions(origLb); @@ -1043,9 +1044,9 @@ public void serviceConfigChangesServiceNameWhenRpcInactive() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); - inOrder.verify(origLb).handleResolvedAddresses(result2); + inOrder.verify(origLb).acceptResolvedAddresses(result2); // Underlying subchannel is now ready deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); @@ -1092,9 +1093,9 @@ public void balancerShutdown() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); verifyNoMoreInteractions(origLb); ServerSideCall[] serverCalls = new ServerSideCall[NUM_SUBCHANNELS]; @@ -1172,8 +1173,8 @@ public LoadBalancer newLoadBalancer(Helper helper) { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); createSubchannel(0, Attributes.EMPTY); assertThat(healthImpls[0].calls).isEmpty(); deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); diff --git a/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java b/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java index 
87d4ac29be8..b2652e92771 100644 --- a/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java +++ b/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java @@ -18,6 +18,11 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import io.grpc.BindableService; import io.grpc.Context; @@ -28,6 +33,7 @@ import io.grpc.health.v1.HealthCheckResponse; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; import io.grpc.health.v1.HealthGrpc; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcServerRule; import java.util.ArrayDeque; @@ -109,6 +115,18 @@ public void enterTerminalState_watch() throws Exception { assertThat(obs.responses).isEmpty(); } + @Test + @SuppressWarnings("unchecked") + public void serverCallStreamObserver_watch() throws Exception { + manager.setStatus(SERVICE1, ServingStatus.SERVING); + ServerCallStreamObserver observer = mock(ServerCallStreamObserver.class); + service.watch(HealthCheckRequest.newBuilder().setService(SERVICE1).build(), observer); + + verify(observer, times(1)) + .onNext(eq(HealthCheckResponse.newBuilder().setStatus(ServingStatus.SERVING).build())); + verify(observer, times(1)).setOnCancelHandler(any(Runnable.class)); + } + @Test public void enterTerminalState_ignoreClear() throws Exception { manager.setStatus(SERVICE1, ServingStatus.SERVING); diff --git a/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceTest.java b/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceTest.java index c9dd1014141..115dd11b0f1 100644 --- a/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceTest.java +++ 
b/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceTest.java @@ -71,7 +71,8 @@ public class ProtoReflectionServiceTest { private static final String TEST_HOST = "localhost"; private MutableHandlerRegistry handlerRegistry = new MutableHandlerRegistry(); - private BindableService reflectionService; + @SuppressWarnings("deprecation") + private BindableService reflectionService = ProtoReflectionService.newInstance(); private ServerServiceDefinition dynamicService = new DynamicServiceGrpc.DynamicServiceImplBase() {}.bindService(); private ServerServiceDefinition anotherDynamicService = @@ -80,7 +81,6 @@ public class ProtoReflectionServiceTest { @Before public void setUp() throws Exception { - reflectionService = ProtoReflectionService.newInstance(); Server server = InProcessServerBuilder.forName("proto-reflection-test") .directExecutor() diff --git a/servlet/build.gradle b/servlet/build.gradle index fd5abb6f0e5..1367a72ab44 100644 --- a/servlet/build.gradle +++ b/servlet/build.gradle @@ -34,8 +34,7 @@ tasks.named("jar").configure { dependencies { api project(':grpc-api') - compileOnly libraries.javax.servlet.api, - libraries.javax.annotation // java 9, 10 needs it + compileOnly libraries.javax.servlet.api implementation project(':grpc-core'), libraries.guava @@ -43,7 +42,7 @@ dependencies { testImplementation libraries.javax.servlet.api threadingTestImplementation project(':grpc-servlet'), - libraries.truth, + libraries.junit, libraries.javax.servlet.api, libraries.lincheck @@ -69,19 +68,12 @@ dependencies { libraries.protobuf.java } -tasks.named("test").configure { - if (JavaVersion.current().isJava9Compatible()) { - jvmArgs += [ - // required for Lincheck - '--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED', - '--add-exports=java.base/jdk.internal.util=ALL-UNNAMED', - ] - } -} - tasks.register('threadingTest', Test) { classpath = sourceSets.threadingTest.runtimeClasspath testClassesDirs = sourceSets.threadingTest.output.classesDirs + jacoco { 
+ enabled = false + } } tasks.named("assemble").configure { diff --git a/servlet/jakarta/build.gradle b/servlet/jakarta/build.gradle index 51333856ddf..bcd904ccaee 100644 --- a/servlet/jakarta/build.gradle +++ b/servlet/jakarta/build.gradle @@ -8,13 +8,15 @@ description = "gRPC: Jakarta Servlet" // Set up classpaths and source directories for different servlet tests sourceSets { - // Only run these tests if java 11+ is being used - if (JavaVersion.current().isJava11Compatible()) { + // Only run these tests if the required minimum Java version is being used + if (JavaVersion.current().isCompatibleWith(JavaVersion.VERSION_17)) { jettyTest { java { include '**/Jetty*.java' } } + } + if (JavaVersion.current().isJava11Compatible()) { tomcatTest { java { include '**/Tomcat*.java' @@ -45,11 +47,15 @@ def migrate(String name, String inputDir, SourceSet sourceSet) { def outputDir = layout.buildDirectory.dir('generated/sources/jakarta-' + name) sourceSet.java.srcDir tasks.register('migrateSources' + name.capitalize(), Sync) { task -> into(outputDir) + // Increment when changing the filter, to inform Gradle it needs to rebuild + inputs.property("filter-version", "1") from("$inputDir/io/grpc/servlet") { into('io/grpc/servlet/jakarta') filter { String line -> line.replace('javax.servlet', 'jakarta.servlet') .replace('io.grpc.servlet', 'io.grpc.servlet.jakarta') + .replace('org.eclipse.jetty.http2.parser', 'org.eclipse.jetty.http2') + .replace('org.eclipse.jetty.servlet', 'org.eclipse.jetty.ee10.servlet') } } } @@ -57,9 +63,11 @@ def migrate(String name, String inputDir, SourceSet sourceSet) { migrate('main', '../src/main/java', sourceSets.main) -// Only build sourceSets and classpaths for tests if using Java 11 -if (JavaVersion.current().isJava11Compatible()) { +// Only build sourceSets and classpaths for tests if using the required minimum Java version +if (JavaVersion.current().isCompatibleWith(JavaVersion.VERSION_17)) { migrate('jettyTest', '../src/jettyTest/java', 
sourceSets.jettyTest) +} +if (JavaVersion.current().isJava11Compatible()) { migrate('tomcatTest', '../src/tomcatTest/java', sourceSets.tomcatTest) migrate('undertowTest', '../src/undertowTest/java', sourceSets.undertowTest) } @@ -77,8 +85,7 @@ tasks.named("jar").configure { dependencies { api project(':grpc-api') - compileOnly libraries.jakarta.servlet.api, - libraries.javax.annotation + compileOnly libraries.jakarta.servlet.api implementation project(':grpc-util'), project(':grpc-core'), @@ -104,12 +111,19 @@ dependencies { // Set up individual classpaths for each test, to avoid any mismatch, // and ensure they are only used when supported by the current jvm -if (JavaVersion.current().isJava11Compatible()) { +if (JavaVersion.current().isCompatibleWith(JavaVersion.VERSION_17)) { def jetty11Test = tasks.register('jetty11Test', Test) { classpath = sourceSets.jettyTest.runtimeClasspath testClassesDirs = sourceSets.jettyTest.output.classesDirs } - + tasks.named('compileJettyTestJava') { JavaCompile task -> + task.options.release.set 9 + } + tasks.named("check").configure { + dependsOn jetty11Test + } +} +if (JavaVersion.current().isJava11Compatible()) { def tomcat10Test = tasks.register('tomcat10Test', Test) { classpath = sourceSets.tomcatTest.runtimeClasspath testClassesDirs = sourceSets.tomcatTest.output.classesDirs @@ -134,6 +148,6 @@ if (JavaVersion.current().isJava11Compatible()) { } tasks.named("check").configure { - dependsOn jetty11Test, tomcat10Test, undertowTest + dependsOn tomcat10Test, undertowTest } } diff --git a/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java b/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java index e9cb391ea08..58143a8516c 100644 --- a/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java +++ b/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java @@ -69,6 +69,7 @@ public void start(ServerListener listener) throws IOException { listener.transportCreated(new 
ServletServerBuilder.ServerTransportImpl(scheduler)); ServletAdapter adapter = new ServletAdapter(serverTransportListener, streamTracerFactories, + ServletAdapter.DEFAULT_METHOD_NAME_RESOLVER, Integer.MAX_VALUE); GrpcServlet grpcServlet = new GrpcServlet(adapter); @@ -76,9 +77,7 @@ public void start(ServerListener listener) throws IOException { ServerConnector sc = (ServerConnector) jettyServer.getConnectors()[0]; HttpConfiguration httpConfiguration = new HttpConfiguration(); - // Must be set for several tests to pass, so that the request handling can begin before - // content arrives. - httpConfiguration.setDelayDispatchUntilContent(false); + setDelayDispatchUntilContent(httpConfiguration); HTTP2CServerConnectionFactory factory = new HTTP2CServerConnectionFactory(httpConfiguration); @@ -134,6 +133,16 @@ protected InternalServer newServer(int port, return newServer(streamTracerFactories); } + // The future default appears to be false as people are supposed to be migrate to + // EagerContentHandler, but the default is still true. Seems they messed up the migration + // process here by not flipping the default. + @SuppressWarnings("removal") + private static void setDelayDispatchUntilContent(HttpConfiguration httpConfiguration) { + // Must be set for several tests to pass, so that the request handling can begin before + // content arrives. 
+ httpConfiguration.setDelayDispatchUntilContent(false); + } + @Override protected ManagedClientTransport newClientTransport(InternalServer server) { NettyChannelBuilder nettyChannelBuilder = NettyChannelBuilder diff --git a/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java b/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java index cfd29b1a2fd..3c8d3d07571 100644 --- a/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java +++ b/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java @@ -22,19 +22,19 @@ import static java.util.logging.Level.FINEST; import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.InternalLogId; import io.grpc.servlet.ServletServerStream.ServletTransportState; import java.io.IOException; -import java.time.Duration; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.LockSupport; import java.util.function.BiFunction; import java.util.function.BooleanSupplier; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.servlet.AsyncContext; import javax.servlet.ServletOutputStream; @@ -128,7 +128,7 @@ public void finest(String str, Object... params) { log.fine("call completed"); }); }; - this.isReady = () -> outputStream.isReady(); + this.isReady = outputStream::isReady; } /** @@ -173,7 +173,9 @@ void complete() { /** Called from the container thread {@link javax.servlet.WriteListener#onWritePossible()}. */ void onWritePossible() throws IOException { log.finest("onWritePossible: ENTRY. 
The servlet output stream becomes ready"); - assureReadyAndDrainedTurnsFalse(); + if (writeState.get().readyAndDrained) { + assureReadyAndDrainedTurnsFalse(); + } while (isReady.getAsBoolean()) { WriteState curState = writeState.get(); @@ -200,11 +202,9 @@ private void assureReadyAndDrainedTurnsFalse() { // readyAndDrained should have been set to false already. // Just in case due to a race condition readyAndDrained is still true at this moment and is // being set to false by runOrBuffer() concurrently. + parkingThread = Thread.currentThread(); while (writeState.get().readyAndDrained) { - parkingThread = Thread.currentThread(); - // Try to sleep for an extremely long time to avoid writeState being changed at exactly - // the time when sleep time expires (in extreme scenario, such as #9917). - LockSupport.parkNanos(Duration.ofHours(1).toNanos()); // should return immediately + LockSupport.parkNanos(TimeUnit.MINUTES.toNanos(1)); // should return immediately } parkingThread = null; } @@ -254,7 +254,7 @@ interface ActionItem { @VisibleForTesting // Lincheck test can not run with java.util.logging dependency. 
interface Log { default boolean isLoggable(Level level) { - return false; + return false; } default void fine(String str, Object...params) {} diff --git a/servlet/src/main/java/io/grpc/servlet/GrpcServlet.java b/servlet/src/main/java/io/grpc/servlet/GrpcServlet.java index f68ed083506..8c1eb858ad1 100644 --- a/servlet/src/main/java/io/grpc/servlet/GrpcServlet.java +++ b/servlet/src/main/java/io/grpc/servlet/GrpcServlet.java @@ -37,6 +37,7 @@ public class GrpcServlet extends HttpServlet { private static final long serialVersionUID = 1L; + @SuppressWarnings("serial") private final ServletAdapter servletAdapter; GrpcServlet(ServletAdapter servletAdapter) { diff --git a/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java b/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java index 5a567916f99..668e82425cb 100644 --- a/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java +++ b/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java @@ -22,6 +22,7 @@ import static java.util.logging.Level.FINE; import static java.util.logging.Level.FINEST; +import com.google.common.annotations.VisibleForTesting; import com.google.common.io.BaseEncoding; import io.grpc.Attributes; import io.grpc.ExperimentalApi; @@ -45,6 +46,7 @@ import java.util.Enumeration; import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.function.Function; import java.util.logging.Logger; import javax.servlet.AsyncContext; import javax.servlet.AsyncEvent; @@ -72,18 +74,23 @@ public final class ServletAdapter { static final Logger logger = Logger.getLogger(ServletAdapter.class.getName()); + static final Function DEFAULT_METHOD_NAME_RESOLVER = + req -> req.getRequestURI().substring(1); // remove the leading "/" private final ServerTransportListener transportListener; private final List streamTracerFactories; + private final Function methodNameResolver; private final int maxInboundMessageSize; private final Attributes attributes; ServletAdapter( ServerTransportListener 
transportListener, List streamTracerFactories, + Function methodNameResolver, int maxInboundMessageSize) { this.transportListener = transportListener; this.streamTracerFactories = streamTracerFactories; + this.methodNameResolver = methodNameResolver; this.maxInboundMessageSize = maxInboundMessageSize; attributes = transportListener.transportReady(Attributes.EMPTY); } @@ -119,7 +126,7 @@ public void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOEx AsyncContext asyncCtx = req.startAsync(req, resp); - String method = req.getRequestURI().substring(1); // remove the leading "/" + String method = methodNameResolver.apply(req); Metadata headers = getHeaders(req); if (logger.isLoggable(FINEST)) { @@ -128,10 +135,9 @@ public void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOEx } Long timeoutNanos = headers.get(TIMEOUT_KEY); - if (timeoutNanos == null) { - timeoutNanos = 0L; - } - asyncCtx.setTimeout(TimeUnit.NANOSECONDS.toMillis(timeoutNanos)); + asyncCtx.setTimeout(timeoutNanos != null + ? TimeUnit.NANOSECONDS.toMillis(timeoutNanos) + ASYNC_TIMEOUT_SAFETY_MARGIN + : 0); StatsTraceContext statsTraceCtx = StatsTraceContext.newServerContext(streamTracerFactories, method, headers); @@ -158,6 +164,12 @@ public void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOEx asyncCtx.addListener(new GrpcAsyncListener(stream, logId)); } + /** + * Deadlines are managed via Context, servlet async timeout is not supposed to happen. + */ + @VisibleForTesting + static final long ASYNC_TIMEOUT_SAFETY_MARGIN = 5_000; + // This method must use Enumeration and its members, since that is the only way to read headers // from the servlet api. 
@SuppressWarnings("JdkObsolete") @@ -215,7 +227,9 @@ private static final class GrpcAsyncListener implements AsyncListener { } @Override - public void onComplete(AsyncEvent event) {} + public void onComplete(AsyncEvent event) { + stream.asyncCompleted = true; + } @Override public void onTimeout(AsyncEvent event) { diff --git a/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java b/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java index 72c4383d273..5bea4c6e03b 100644 --- a/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java +++ b/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java @@ -49,8 +49,10 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.ScheduledExecutorService; +import java.util.function.Function; import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; +import javax.servlet.http.HttpServletRequest; /** * Builder to build a gRPC server that can run as a servlet. This is for advanced custom settings. @@ -64,6 +66,8 @@ @NotThreadSafe public final class ServletServerBuilder extends ForwardingServerBuilder { List streamTracerFactories; + private Function methodNameResolver = + ServletAdapter.DEFAULT_METHOD_NAME_RESOLVER; int maxInboundMessageSize = DEFAULT_MAX_MESSAGE_SIZE; private final ServerImplBuilder serverImplBuilder; @@ -74,7 +78,9 @@ public final class ServletServerBuilder extends ForwardingServerBuilder + buildTransportServers(streamTracerFactories)); } /** @@ -98,7 +104,8 @@ public Server build() { * Creates a {@link ServletAdapter}. 
*/ public ServletAdapter buildServletAdapter() { - return new ServletAdapter(buildAndStart(), streamTracerFactories, maxInboundMessageSize); + return new ServletAdapter(buildAndStart(), streamTracerFactories, methodNameResolver, + maxInboundMessageSize); } /** @@ -176,6 +183,18 @@ public ServletServerBuilder useTransportSecurity(File certChain, File privateKey throw new UnsupportedOperationException("TLS should be configured by the servlet container"); } + /** + * Specifies how to determine gRPC method name from servlet request. + * + *

The default strategy is using {@link HttpServletRequest#getRequestURI()} without the leading + * slash.

+ */ + public ServletServerBuilder methodNameResolver( + Function methodResolver) { + this.methodNameResolver = checkNotNull(methodResolver); + return this; + } + @Override public ServletServerBuilder maxInboundMessageSize(int bytes) { checkArgument(bytes >= 0, "bytes must be >= 0"); diff --git a/servlet/src/main/java/io/grpc/servlet/ServletServerStream.java b/servlet/src/main/java/io/grpc/servlet/ServletServerStream.java index b7ad6e0decc..0182f302698 100644 --- a/servlet/src/main/java/io/grpc/servlet/ServletServerStream.java +++ b/servlet/src/main/java/io/grpc/servlet/ServletServerStream.java @@ -30,7 +30,6 @@ import io.grpc.InternalLogId; import io.grpc.Metadata; import io.grpc.Status; -import io.grpc.Status.Code; import io.grpc.internal.AbstractServerStream; import io.grpc.internal.GrpcUtil; import io.grpc.internal.SerializingExecutor; @@ -43,8 +42,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; import java.util.function.Supplier; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -58,12 +56,15 @@ final class ServletServerStream extends AbstractServerStream { private final ServletTransportState transportState; private final Sink sink = new Sink(); - private final AsyncContext asyncCtx; private final HttpServletResponse resp; private final Attributes attributes; private final String authority; private final InternalLogId logId; private final AsyncServletOutputStreamWriter writer; + /** + * If the async servlet operation has been completed. 
+ */ + volatile boolean asyncCompleted = false; ServletServerStream( AsyncContext asyncCtx, @@ -78,7 +79,6 @@ final class ServletServerStream extends AbstractServerStream { this.attributes = attributes; this.authority = authority; this.logId = logId; - this.asyncCtx = asyncCtx; this.resp = (HttpServletResponse) asyncCtx.getResponse(); this.writer = new AsyncServletOutputStreamWriter( asyncCtx, transportState, logId); @@ -123,9 +123,13 @@ private void writeHeadersToServletResponse(Metadata metadata) { resp.setStatus(HttpServletResponse.SC_OK); resp.setContentType(CONTENT_TYPE_GRPC); + serializeHeaders(metadata, resp::addHeader); + } + + private static void serializeHeaders(Metadata metadata, BiConsumer consumer) { byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(metadata); for (int i = 0; i < serializedHeaders.length; i += 2) { - resp.addHeader( + consumer.accept( new String(serializedHeaders[i], StandardCharsets.US_ASCII), new String(serializedHeaders[i + 1], StandardCharsets.US_ASCII)); } @@ -154,8 +158,8 @@ public void bytesRead(int numBytes) { @Override public void deframeFailed(Throwable cause) { - if (logger.isLoggable(FINE)) { - logger.log(FINE, String.format("[{%s}] Exception processing message", logId), cause); + if (logger.isLoggable(WARNING)) { + logger.log(WARNING, String.format("[{%s}] Exception processing message", logId), cause); } cancel(Status.fromThrowable(cause)); } @@ -168,7 +172,7 @@ private static final class ByteArrayWritableBuffer implements WritableBuffer { private int index; ByteArrayWritableBuffer(int capacityHint) { - this.bytes = new byte[min(1024 * 1024, max(4096, capacityHint))]; + this.bytes = new byte[min(1024 * 1024, capacityHint)]; this.capacity = bytes.length; } @@ -278,13 +282,8 @@ public void writeTrailers(Metadata trailers, boolean headersSent, Status status) if (!headersSent) { writeHeadersToServletResponse(trailers); } else { - byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(trailers); - for (int 
i = 0; i < serializedHeaders.length; i += 2) { - String key = new String(serializedHeaders[i], StandardCharsets.US_ASCII); - String newValue = new String(serializedHeaders[i + 1], StandardCharsets.US_ASCII); - trailerSupplier.get().computeIfPresent(key, (k, v) -> v + "," + newValue); - trailerSupplier.get().putIfAbsent(key, newValue); - } + serializeHeaders(trailers, + (k, v) -> trailerSupplier.get().merge(k, v, (oldV, newV) -> oldV + "," + newV)); } writer.complete(); @@ -292,22 +291,14 @@ public void writeTrailers(Metadata trailers, boolean headersSent, Status status) @Override public void cancel(Status status) { - if (resp.isCommitted() && Code.DEADLINE_EXCEEDED == status.getCode()) { - return; // let the servlet timeout, the container will sent RST_STREAM automatically - } transportState.runOnTransportThread(() -> transportState.transportReportStatus(status)); - // There is no way to RST_STREAM with CANCEL code, so write trailers instead - close(Status.CANCELLED.withCause(status.asRuntimeException()), new Metadata()); - CountDownLatch countDownLatch = new CountDownLatch(1); - transportState.runOnTransportThread(() -> { - asyncCtx.complete(); - countDownLatch.countDown(); - }); - try { - countDownLatch.await(5, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); + if (asyncCompleted) { + logger.fine("ignore cancel as already completed"); + return; } + // There is no way to RST_STREAM with CANCEL code, so write trailers instead + close(status, new Metadata()); + // close() calls writeTrailers(), which calls AsyncContext.complete() } } diff --git a/servlet/src/test/java/io/grpc/servlet/ServletServerBuilderTest.java b/servlet/src/test/java/io/grpc/servlet/ServletServerBuilderTest.java index d571cfd45d5..7a8c5b91f25 100644 --- a/servlet/src/test/java/io/grpc/servlet/ServletServerBuilderTest.java +++ b/servlet/src/test/java/io/grpc/servlet/ServletServerBuilderTest.java @@ -80,7 +80,7 @@ public void scheduledExecutorService() 
throws Exception { ServletAdapter servletAdapter = serverBuilder.buildServletAdapter(); servletAdapter.doPost(request, response); - verify(asyncContext).setTimeout(1); + verify(asyncContext).setTimeout(1 + ServletAdapter.ASYNC_TIMEOUT_SAFETY_MARGIN); // The following just verifies that scheduler is populated to the transport. // It doesn't matter what tasks (such as handshake timeout and request deadline) are actually diff --git a/servlet/src/threadingTest/java/io/grpc/servlet/AsyncServletOutputStreamWriterConcurrencyTest.java b/servlet/src/threadingTest/java/io/grpc/servlet/AsyncServletOutputStreamWriterConcurrencyTest.java index 61da2bf4c69..b2891b6e47e 100644 --- a/servlet/src/threadingTest/java/io/grpc/servlet/AsyncServletOutputStreamWriterConcurrencyTest.java +++ b/servlet/src/threadingTest/java/io/grpc/servlet/AsyncServletOutputStreamWriterConcurrencyTest.java @@ -16,23 +16,22 @@ package io.grpc.servlet; -import static com.google.common.truth.Truth.assertWithMessage; -import static org.jetbrains.kotlinx.lincheck.strategy.managed.ManagedStrategyGuaranteeKt.forClasses; +import static org.jetbrains.lincheck.datastructures.ManagedStrategyGuaranteeKt.forClasses; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import io.grpc.servlet.AsyncServletOutputStreamWriter.ActionItem; import io.grpc.servlet.AsyncServletOutputStreamWriter.Log; import java.io.IOException; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; -import org.jetbrains.kotlinx.lincheck.LinChecker; -import org.jetbrains.kotlinx.lincheck.annotations.OpGroupConfig; -import org.jetbrains.kotlinx.lincheck.annotations.Operation; -import org.jetbrains.kotlinx.lincheck.annotations.Param; -import org.jetbrains.kotlinx.lincheck.paramgen.BooleanGen; -import 
org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingCTest; -import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingOptions; -import org.jetbrains.kotlinx.lincheck.verifier.VerifierState; +import org.jetbrains.lincheck.datastructures.BooleanGen; +import org.jetbrains.lincheck.datastructures.ModelCheckingOptions; +import org.jetbrains.lincheck.datastructures.Operation; +import org.jetbrains.lincheck.datastructures.Param; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -49,18 +48,19 @@ * test all possibly interleaves (on context switch) between the two threads, and then verify the * operations are linearizable in each interleave scenario. */ -@ModelCheckingCTest -@OpGroupConfig(name = "update", nonParallel = true) -@OpGroupConfig(name = "write", nonParallel = true) @Param(name = "keepReady", gen = BooleanGen.class) @RunWith(JUnit4.class) -public class AsyncServletOutputStreamWriterConcurrencyTest extends VerifierState { +public class AsyncServletOutputStreamWriterConcurrencyTest { private static final int OPERATIONS_PER_THREAD = 6; private final AsyncServletOutputStreamWriter writer; private final boolean[] keepReadyArray = new boolean[OPERATIONS_PER_THREAD]; private volatile boolean isReady; + /** + * The container initiates the first call shortly after {@code startAsync}. + */ + private final AtomicBoolean initialOnWritePossible = new AtomicBoolean(true); // when isReadyReturnedFalse, writer.onWritePossible() will be called. 
private volatile boolean isReadyReturnedFalse; private int producerIndex; @@ -71,17 +71,15 @@ public class AsyncServletOutputStreamWriterConcurrencyTest extends VerifierState public AsyncServletOutputStreamWriterConcurrencyTest() { BiFunction writeAction = (bytes, numBytes) -> () -> { - assertWithMessage("write should only be called while isReady() is true") - .that(isReady) - .isTrue(); + assertTrue("write should only be called while isReady() is true", isReady); // The byte to be written must equal to consumerIndex, otherwise execution order is wrong - assertWithMessage("write in wrong order").that(bytes[0]).isEqualTo((byte) consumerIndex); + assertEquals("write in wrong order", bytes[0], (byte) consumerIndex); bytesWritten++; writeOrFlush(); }; ActionItem flushAction = () -> { - assertWithMessage("flush must only be called while isReady() is true").that(isReady).isTrue(); + assertTrue("flush must only be called while isReady() is true", isReady); writeOrFlush(); }; @@ -102,12 +100,13 @@ private void writeOrFlush() { } private boolean isReady() { - if (!isReady) { - assertWithMessage("isReady() already returned false, onWritePossible() will be invoked") - .that(isReadyReturnedFalse).isFalse(); + boolean copyOfIsReady = isReady; + if (!copyOfIsReady) { + assertFalse("isReady() already returned false, onWritePossible() will be invoked", + isReadyReturnedFalse); isReadyReturnedFalse = true; } - return isReady; + return copyOfIsReady; } /** @@ -118,7 +117,7 @@ private boolean isReady() { * the ServletOutputStream should become unready if keepReady == false. 
*/ // @com.google.errorprone.annotations.Keep - @Operation(group = "write") + @Operation(nonParallelGroup = "write") public void write(@Param(name = "keepReady") boolean keepReady) throws IOException { keepReadyArray[producerIndex] = keepReady; writer.writeBytes(new byte[]{(byte) producerIndex}, 1); @@ -133,7 +132,7 @@ public void write(@Param(name = "keepReady") boolean keepReady) throws IOExcepti * the ServletOutputStream should become unready if keepReady == false. */ // @com.google.errorprone.annotations.Keep // called by lincheck reflectively - @Operation(group = "write") + @Operation(nonParallelGroup = "write") public void flush(@Param(name = "keepReady") boolean keepReady) throws IOException { keepReadyArray[producerIndex] = keepReady; writer.flush(); @@ -142,9 +141,12 @@ public void flush(@Param(name = "keepReady") boolean keepReady) throws IOExcepti /** If the writer is not ready, let it turn ready and call writer.onWritePossible(). */ // @com.google.errorprone.annotations.Keep // called by lincheck reflectively - @Operation(group = "update") + @Operation(nonParallelGroup = "update") public void maybeOnWritePossible() throws IOException { - if (isReadyReturnedFalse) { + if (initialOnWritePossible.compareAndSet(true, false)) { + isReady = true; + writer.onWritePossible(); + } else if (isReadyReturnedFalse) { isReadyReturnedFalse = false; isReady = true; writer.onWritePossible(); @@ -152,7 +154,13 @@ public void maybeOnWritePossible() throws IOException { } @Override - protected Object extractState() { + public final boolean equals(Object o) { + return o instanceof AsyncServletOutputStreamWriterConcurrencyTest + && bytesWritten == ((AsyncServletOutputStreamWriterConcurrencyTest) o).bytesWritten; + } + + @Override + public int hashCode() { return bytesWritten; } @@ -169,6 +177,6 @@ public void linCheck() { AtomicReference.class.getName()) .allMethods() .treatAsAtomic()); - LinChecker.check(AsyncServletOutputStreamWriterConcurrencyTest.class, options); + 
options.check(AsyncServletOutputStreamWriterConcurrencyTest.class); } } diff --git a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatInteropTest.java b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatInteropTest.java index 1422b5388fd..d072fea93a1 100644 --- a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatInteropTest.java +++ b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatInteropTest.java @@ -113,27 +113,28 @@ protected boolean metricsExpected() { @Test public void gracefulShutdown() {} - // FIXME @Override @Ignore("Tomcat is not able to send trailer only") @Test public void specialStatusMessage() {} - // FIXME @Override @Ignore("Tomcat is not able to send trailer only") @Test public void unimplementedMethod() {} - // FIXME @Override @Ignore("Tomcat is not able to send trailer only") @Test public void statusCodeAndMessage() {} - // FIXME @Override @Ignore("Tomcat is not able to send trailer only") @Test public void emptyStream() {} + + @Override + @Ignore("Tomcat is not able to send trailer only") + @Test + public void timeoutOnSleepingServer() {} } diff --git a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java index 262036883a9..cd73b096ccb 100644 --- a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java +++ b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java @@ -81,7 +81,9 @@ public void start(ServerListener listener) throws IOException { ServerTransportListener serverTransportListener = listener.transportCreated(new ServerTransportImpl(scheduler)); ServletAdapter adapter = - new ServletAdapter(serverTransportListener, streamTracerFactories, Integer.MAX_VALUE); + new ServletAdapter(serverTransportListener, streamTracerFactories, + ServletAdapter.DEFAULT_METHOD_NAME_RESOLVER, + Integer.MAX_VALUE); GrpcServlet grpcServlet = new GrpcServlet(adapter); tomcatServer = new Tomcat(); @@ -91,6 +93,10 @@ public void start(ServerListener 
listener) throws IOException { .setAsyncSupported(true); ctx.addServletMappingDecoded("/*", "TomcatTransportTest"); tomcatServer.getConnector().addUpgradeProtocol(new Http2Protocol()); + // Workaround for https://github.com/grpc/grpc-java/issues/12540 + // Prevent premature OutputBuffer recycling by disabling facade recycling. + // This should be revisited once the root cause is fixed. + tomcatServer.getConnector().setDiscardFacades(false); try { tomcatServer.start(); } catch (LifecycleException e) { diff --git a/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java b/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java index e14c11985de..ef897c87d70 100644 --- a/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java +++ b/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java @@ -100,7 +100,9 @@ public void start(ServerListener listener) throws IOException { ServerTransportListener serverTransportListener = listener.transportCreated(new ServerTransportImpl(scheduler)); ServletAdapter adapter = - new ServletAdapter(serverTransportListener, streamTracerFactories, Integer.MAX_VALUE); + new ServletAdapter(serverTransportListener, streamTracerFactories, + ServletAdapter.DEFAULT_METHOD_NAME_RESOLVER, + Integer.MAX_VALUE); GrpcServlet grpcServlet = new GrpcServlet(adapter); InstanceFactory instanceFactory = () -> new ImmediateInstanceHandle<>(grpcServlet); diff --git a/settings.gradle b/settings.gradle index 61972e30b6c..51c4bdc0d3d 100644 --- a/settings.gradle +++ b/settings.gradle @@ -1,29 +1,44 @@ pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. 
+ buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } plugins { // https://developer.android.com/build/releases/gradle-plugin // 8+ has many changes: https://github.com/grpc/grpc-java/issues/10152 id "com.android.application" version "7.4.1" id "com.android.library" version "7.4.1" - // https://github.com/johnrengelman/shadow/releases - id "com.github.johnrengelman.shadow" version "8.1.1" // https://github.com/kt3k/coveralls-gradle-plugin/tags id "com.github.kt3k.coveralls" version "2.12.2" // https://github.com/GoogleCloudPlatform/appengine-plugins/releases - id "com.google.cloud.tools.appengine" version "2.8.0" + id "com.google.cloud.tools.appengine" version "2.8.6" // https://github.com/GoogleContainerTools/jib/blob/master/jib-gradle-plugin/CHANGELOG.md - id "com.google.cloud.tools.jib" version "3.4.3" + id "com.google.cloud.tools.jib" version "3.5.1" // https://github.com/google/osdetector-gradle-plugin/tags id "com.google.osdetector" version "1.7.3" // https://github.com/google/protobuf-gradle-plugin/releases - id "com.google.protobuf" version "0.9.4" + id "com.google.protobuf" version "0.9.5" + // https://github.com/GradleUp/shadow/releases + // 8.3.2+ requires Java 11+ + // 8.3.1 breaks apache imports for netty/shaded, fixed in 8.3.2 + id "com.gradleup.shadow" version "8.3.0" // https://github.com/melix/japicmp-gradle-plugin/blob/master/CHANGELOG.txt id "me.champeau.gradle.japicmp" version "0.4.2" // https://github.com/melix/jmh-gradle-plugin/releases - id "me.champeau.jmh" version "0.7.2" + id "me.champeau.jmh" version "0.7.3" // https://github.com/tbroyer/gradle-errorprone-plugin/releases - id "net.ltgt.errorprone" version "4.0.1" + id "net.ltgt.errorprone" version "4.3.0" // https://github.com/xvik/gradle-animalsniffer-plugin/releases - id "ru.vyarus.animalsniffer" version "1.7.1" + id "ru.vyarus.animalsniffer" 
version "2.0.1" } resolutionStrategy { eachPlugin { @@ -65,6 +80,7 @@ include ":grpc-benchmarks" include ":grpc-services" include ":grpc-servlet" include ":grpc-servlet-jakarta" +include ":grpc-s2a" include ":grpc-xds" include ":grpc-bom" include ":grpc-rls" @@ -76,6 +92,7 @@ include ":grpc-istio-interop-testing" include ":grpc-inprocess" include ":grpc-util" include ":grpc-opentelemetry" +include ":grpc-context-override-opentelemetry" project(':grpc-api').projectDir = "$rootDir/api" as File project(':grpc-core').projectDir = "$rootDir/core" as File @@ -100,6 +117,7 @@ project(':grpc-benchmarks').projectDir = "$rootDir/benchmarks" as File project(':grpc-services').projectDir = "$rootDir/services" as File project(':grpc-servlet').projectDir = "$rootDir/servlet" as File project(':grpc-servlet-jakarta').projectDir = "$rootDir/servlet/jakarta" as File +project(':grpc-s2a').projectDir = "$rootDir/s2a" as File project(':grpc-xds').projectDir = "$rootDir/xds" as File project(':grpc-bom').projectDir = "$rootDir/bom" as File project(':grpc-rls').projectDir = "$rootDir/rls" as File @@ -111,6 +129,7 @@ project(':grpc-istio-interop-testing').projectDir = "$rootDir/istio-interop-test project(':grpc-inprocess').projectDir = "$rootDir/inprocess" as File project(':grpc-util').projectDir = "$rootDir/util" as File project(':grpc-opentelemetry').projectDir = "$rootDir/opentelemetry" as File +project(':grpc-context-override-opentelemetry').projectDir = "$rootDir/contextstorage" as File if (settings.hasProperty('skipCodegen') && skipCodegen.toBoolean()) { println '*** Skipping the build of codegen and compilation of proto files because skipCodegen=true' diff --git a/stub/BUILD.bazel b/stub/BUILD.bazel index 6d06e01f918..f9188c27272 100644 --- a/stub/BUILD.bazel +++ b/stub/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( @@ -12,14 +13,6 @@ java_library( 
artifact("com.google.code.findbugs:jsr305"), artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), + artifact("org.codehaus.mojo:animal-sniffer-annotations"), ], ) - -# javax.annotation.Generated is not included in the default root modules in 9, -# see: http://openjdk.java.net/jeps/320. -java_library( - name = "javax_annotation", - neverlink = 1, # @Generated is source-retention - visibility = ["//visibility:public"], - exports = [artifact("org.apache.tomcat:annotations-api")], -) diff --git a/stub/build.gradle b/stub/build.gradle index 867936f3ea3..2dabd9e6202 100644 --- a/stub/build.gradle +++ b/stub/build.gradle @@ -16,14 +16,23 @@ tasks.named("jar").configure { dependencies { api project(':grpc-api'), + libraries.animalsniffer.annotations, libraries.guava implementation libraries.errorprone.annotations testImplementation libraries.truth, project(':grpc-inprocess'), project(':grpc-testing'), testFixtures(project(':grpc-api')) - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("javadoc").configure { diff --git a/stub/src/main/java/io/grpc/stub/AbstractAsyncStub.java b/stub/src/main/java/io/grpc/stub/AbstractAsyncStub.java index c6f912cb3a7..f369eeaf87f 100644 --- a/stub/src/main/java/io/grpc/stub/AbstractAsyncStub.java +++ b/stub/src/main/java/io/grpc/stub/AbstractAsyncStub.java @@ -16,10 +16,10 @@ package io.grpc.stub; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.stub.ClientCalls.StubType; -import javax.annotation.CheckReturnValue; import javax.annotation.concurrent.ThreadSafe; /** diff --git a/stub/src/main/java/io/grpc/stub/AbstractBlockingStub.java b/stub/src/main/java/io/grpc/stub/AbstractBlockingStub.java index 
1cb919e67b0..4bdb3c0bb94 100644 --- a/stub/src/main/java/io/grpc/stub/AbstractBlockingStub.java +++ b/stub/src/main/java/io/grpc/stub/AbstractBlockingStub.java @@ -16,10 +16,10 @@ package io.grpc.stub; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.stub.ClientCalls.StubType; -import javax.annotation.CheckReturnValue; import javax.annotation.concurrent.ThreadSafe; /** diff --git a/stub/src/main/java/io/grpc/stub/AbstractFutureStub.java b/stub/src/main/java/io/grpc/stub/AbstractFutureStub.java index 66570bcd6ff..5e37b1e4915 100644 --- a/stub/src/main/java/io/grpc/stub/AbstractFutureStub.java +++ b/stub/src/main/java/io/grpc/stub/AbstractFutureStub.java @@ -16,10 +16,10 @@ package io.grpc.stub; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.stub.ClientCalls.StubType; -import javax.annotation.CheckReturnValue; import javax.annotation.concurrent.ThreadSafe; /** diff --git a/stub/src/main/java/io/grpc/stub/AbstractStub.java b/stub/src/main/java/io/grpc/stub/AbstractStub.java index 0b6f86f2acf..697107760db 100644 --- a/stub/src/main/java/io/grpc/stub/AbstractStub.java +++ b/stub/src/main/java/io/grpc/stub/AbstractStub.java @@ -17,7 +17,9 @@ package io.grpc.stub; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.InternalTimeUtils.convert; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.Channel; @@ -26,11 +28,12 @@ import io.grpc.Deadline; import io.grpc.ExperimentalApi; import io.grpc.ManagedChannelBuilder; +import java.time.Duration; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; 
/** * Common base type for stub implementations. Stub configuration is immutable; changing the @@ -149,6 +152,12 @@ public final S withDeadlineAfter(long duration, TimeUnit unit) { return build(channel, callOptions.withDeadlineAfter(duration, unit)); } + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11657") + @IgnoreJRERequirement + public final S withDeadlineAfter(Duration duration) { + return withDeadlineAfter(convert(duration), TimeUnit.NANOSECONDS); + } + /** * Returns a new stub with the given executor that is to be used instead of the default one * specified with {@link ManagedChannelBuilder#executor}. Note that setting this option may not diff --git a/stub/src/main/java/io/grpc/stub/BlockingClientCall.java b/stub/src/main/java/io/grpc/stub/BlockingClientCall.java new file mode 100644 index 00000000000..6a52ce50776 --- /dev/null +++ b/stub/src/main/java/io/grpc/stub/BlockingClientCall.java @@ -0,0 +1,352 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.stub; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import io.grpc.ClientCall; +import io.grpc.ExperimentalApi; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.stub.ClientCalls.ThreadSafeThreadlessExecutor; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Represents a bidirectional streaming call from a client. Allows in a blocking manner, sending + * over the stream and receiving from the stream. Also supports terminating the call. + * Wraps a ClientCall and converts from async communication to the sync paradigm used by the + * various blocking stream methods in {@link ClientCalls} which are used by the generated stubs. + * + *

Supports separate threads for reads and writes, but only 1 of each + * + *

Read methods consist of: + *

    + *
  • {@link #read()} + *
  • {@link #read(long timeout, TimeUnit unit)} + *
  • {@link #hasNext()} + *
  • {@link #cancel(String, Throwable)} + *
+ * + *

Write methods consist of: + *

    + *
  • {@link #write(Object)} + *
  • {@link #write(Object, long timeout, TimeUnit unit)} + *
  • {@link #halfClose()} + *
+ * + * @param Type of the Request Message + * @param Type of the Response Message + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") +public final class BlockingClientCall { + + private static final Logger logger = Logger.getLogger(BlockingClientCall.class.getName()); + + private final BlockingQueue buffer; + private final ClientCall call; + + private final ThreadSafeThreadlessExecutor executor; + + private boolean writeClosed; + private AtomicReference closeState = new AtomicReference<>(); + + BlockingClientCall(ClientCall call, ThreadSafeThreadlessExecutor executor) { + this.call = call; + this.executor = executor; + buffer = new ArrayBlockingQueue<>(1); + } + + /** + * Wait if necessary for a value to be available from the server. If there is an available value + * return it immediately, if the stream is closed return a null. Otherwise, wait for a value to be + * available or the stream to be closed + * + * @return value from server or null if stream has been closed + * @throws StatusException If the stream has closed in an error state + */ + public RespT read() throws InterruptedException, StatusException { + try { + return read(true, 0); + } catch (TimeoutException e) { + throw new AssertionError("should never happen", e); + } + } + + /** + * Wait with timeout, if necessary, for a value to be available from the server. If there is an + * available value, return it immediately. If the stream is closed return a null. Otherwise, wait + * for a value to be available, the stream to be closed or the timeout to expire. + * + * @param timeout how long to wait before giving up. 
Values <= 0 are no wait + * @param unit a TimeUnit determining how to interpret the timeout parameter + * @return value from server or null (if stream has been closed) + * @throws TimeoutException if no read becomes ready before the specified timeout expires + * @throws StatusException If the stream has closed in an error state + */ + public RespT read(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException, + StatusException { + long endNanoTime = System.nanoTime() + unit.toNanos(timeout); + return read(false, endNanoTime); + } + + private RespT read(boolean waitForever, long endNanoTime) + throws InterruptedException, TimeoutException, StatusException { + Predicate> predicate = BlockingClientCall::skipWaitingForRead; + executor.waitAndDrainWithTimeout(waitForever, endNanoTime, predicate, this); + RespT bufferedValue = buffer.poll(); + + if (logger.isLoggable(Level.FINER)) { + logger.finer("Client Blocking read had value: " + bufferedValue); + } + + CloseState currentCloseState; + if (bufferedValue != null) { + call.request(1); + return bufferedValue; + } else if ((currentCloseState = closeState.get()) == null) { + throw new IllegalStateException( + "The message disappeared... are you reading from multiple threads?"); + } else if (!currentCloseState.status.isOk()) { + throw currentCloseState.status.asException(currentCloseState.trailers); + } else { + return null; + } + } + + boolean skipWaitingForRead() { + return closeState.get() != null || !buffer.isEmpty(); + } + + /** + * Wait for a value to be available from the server. If there is an + * available value, return true immediately. If the stream was closed with Status.OK, return + * false. If the stream was closed with an error status, throw a StatusException. Otherwise, wait + * for a value to be available or the stream to be closed. + * + * @return True when there is a value to read. Return false if stream closed cleanly. 
+ * @throws StatusException If the stream was closed in an error state + */ + public boolean hasNext() throws InterruptedException, StatusException { + executor.waitAndDrain((x) -> !x.buffer.isEmpty() || x.closeState.get() != null, this); + + CloseState currentCloseState = closeState.get(); + if (currentCloseState != null && !currentCloseState.status.isOk()) { + throw currentCloseState.status.asException(currentCloseState.trailers); + } + + return !buffer.isEmpty(); + } + + /** + * Send a value to the stream for sending to server, wait if necessary for the grpc stream to be + * ready. + * + *

If write is not legal at the time of call, immediately returns false + * + *


NOTE: This method will return as soon as it passes the request to the grpc stream + * layer. It will not block while the message is being sent on the wire and returning true does + * not guarantee that the server gets the message. + * + *


WARNING: Doing only writes without reads can lead to deadlocks. This is because + * flow control, imposed by networks to protect intermediary routers and endpoints that are + * operating under resource constraints, requires reads to be done in order to progress writes. + * Furthermore, the server closing the stream will only be identified after + * the last sent value is read. + * + * @param request Message to send to the server + * @return true if the request is sent to stream, false if skipped + * @throws StatusException If the stream has closed in an error state + */ + public boolean write(ReqT request) throws InterruptedException, StatusException { + try { + return write(true, request, 0); + } catch (TimeoutException e) { + throw new RuntimeException(e); // should never happen + } + } + + /** + * Send a value to the stream for sending to server, wait if necessary for the grpc stream to be + * ready up to specified timeout. + * + *

If write is not legal at the time of call, immediately returns false + * + *


NOTE: This method will return as soon as it passes the request to the grpc stream + * layer. It will not block while the message is being sent on the wire and returning true does + * not guarantee that the server gets the message. + * + *


WARNING: Doing only writes without reads can lead to deadlocks as a result of + * flow control. Furthermore, the server closing the stream will only be identified after the + * last sent value is read. + * + * @param request Message to send to the server + * @param timeout How long to wait before giving up. Values <= 0 are no wait + * @param unit A TimeUnit determining how to interpret the timeout parameter + * @return true if the request is sent to stream, false if skipped + * @throws TimeoutException if write does not become ready before the specified timeout expires + * @throws StatusException If the stream has closed in an error state + */ + public boolean write(ReqT request, long timeout, TimeUnit unit) + throws InterruptedException, TimeoutException, StatusException { + long endNanoTime = System.nanoTime() + unit.toNanos(timeout); + return write(false, request, endNanoTime); + } + + private boolean write(boolean waitForever, ReqT request, long endNanoTime) + throws InterruptedException, TimeoutException, StatusException { + + if (writeClosed) { + throw new IllegalStateException("Writes cannot be done after calling halfClose or cancel"); + } + + Predicate> predicate = + (x) -> x.call.isReady() || x.closeState.get() != null; + executor.waitAndDrainWithTimeout(waitForever, endNanoTime, predicate, this); + CloseState savedCloseState = closeState.get(); + if (savedCloseState == null) { + call.sendMessage(request); + return true; + } else if (savedCloseState.status.isOk()) { + return false; + } else { + throw savedCloseState.status.asException(savedCloseState.trailers); + } + } + + void sendSingleRequest(ReqT request) { + call.sendMessage(request); + } + + /** + * Cancel stream and stop any further writes. Note that some reads that are in flight may still + * happen after the cancel. 
+ * + * @param message if not {@code null}, will appear as the description of the CANCELLED status + * @param cause if not {@code null}, will appear as the cause of the CANCELLED status + */ + public void cancel(String message, Throwable cause) { + writeClosed = true; + call.cancel(message, cause); + } + + /** + * Indicate that no more writes will be done and the stream will be closed from the client side. + * + * @see ClientCall#halfClose() + */ + public void halfClose() { + if (writeClosed) { + throw new IllegalStateException( + "halfClose cannot be called after already half closed or cancelled"); + } + + writeClosed = true; + call.halfClose(); + } + + /** + * Status that server sent when closing channel from its side. + * + * @return null if stream not closed by server, otherwise Status sent by server + */ + @VisibleForTesting + Status getClosedStatus() { + executor.drain(); + CloseState state = closeState.get(); + return (state == null) ? null : state.status; + } + + /** + * Check for whether some action is ready. + * + * @return True if legal to write and writeOrRead can run without blocking + */ + @VisibleForTesting + boolean isEitherReadOrWriteReady() { + return (isWriteLegal() && isWriteReady()) || isReadReady(); + } + + /** + * Check whether there are any values waiting to be read. + * + * @return true if read will not block + */ + @VisibleForTesting + boolean isReadReady() { + executor.drain(); + + return !buffer.isEmpty(); + } + + /** + * Check that write hasn't been marked complete and stream is ready to receive a write (so will + * not block). + * + * @return true if legal to write and write will not block + */ + @VisibleForTesting + boolean isWriteReady() { + executor.drain(); + + return isWriteLegal() && call.isReady(); + } + + /** + * Check whether we'll ever be able to do writes or should terminate. 
+ * @return True if writes haven't been closed and the server hasn't closed the stream + */ + private boolean isWriteLegal() { + return !writeClosed && closeState.get() == null; + } + + ClientCall.Listener getListener() { + return new QueuingListener(); + } + + private final class QueuingListener extends ClientCall.Listener { + @Override + public void onMessage(RespT value) { + Preconditions.checkState(closeState.get() == null, "ClientCall already closed"); + buffer.add(value); + } + + @Override + public void onClose(Status status, Metadata trailers) { + CloseState newCloseState = new CloseState(status, trailers); + boolean wasSet = closeState.compareAndSet(null, newCloseState); + Preconditions.checkState(wasSet, "ClientCall already closed"); + } + } + + private static final class CloseState { + final Status status; + final Metadata trailers; + + CloseState(Status status, Metadata trailers) { + this.status = Preconditions.checkNotNull(status, "status"); + this.trailers = trailers; + } + } +} diff --git a/stub/src/main/java/io/grpc/stub/ClientCalls.java b/stub/src/main/java/io/grpc/stub/ClientCalls.java index 13fb00d3b3e..ff2804a0a1f 100644 --- a/stub/src/main/java/io/grpc/stub/ClientCalls.java +++ b/stub/src/main/java/io/grpc/stub/ClientCalls.java @@ -22,12 +22,14 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; import com.google.common.base.Strings; import com.google.common.util.concurrent.AbstractFuture; import com.google.common.util.concurrent.ListenableFuture; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; +import io.grpc.ExperimentalApi; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -42,9 +44,14 @@ import java.util.concurrent.Executor; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; +import 
java.util.concurrent.TimeoutException; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.LockSupport; +import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; import java.util.logging.Logger; +import javax.annotation.Nonnull; import javax.annotation.Nullable; /** @@ -175,6 +182,23 @@ public static RespT blockingUnaryCall( } } + /** + * Executes a unary call and blocks on the response, + * throws a checked {@link StatusException}. + * + * @return the single response message. + * @throws StatusException on error + */ + public static RespT blockingV2UnaryCall( + Channel channel, MethodDescriptor method, CallOptions callOptions, ReqT req) + throws StatusException { + try { + return blockingUnaryCall(channel, method, callOptions, req); + } catch (StatusRuntimeException e) { + throw e.getStatus().asException(e.getTrailers()); + } + } + /** * Executes a server-streaming call returning a blocking {@link Iterator} over the * response stream. The {@code call} should not be already started. After calling this method, @@ -184,7 +208,6 @@ public static RespT blockingUnaryCall( * * @return an iterator over the response stream. */ - // TODO(louiscryan): Not clear if we want to use this idiom for 'simple' stubs. public static Iterator blockingServerStreamingCall( ClientCall call, ReqT req) { BlockingResponseStream result = new BlockingResponseStream<>(call); @@ -194,11 +217,12 @@ public static Iterator blockingServerStreamingCall( /** * Executes a server-streaming call returning a blocking {@link Iterator} over the - * response stream. The {@code call} should not be already started. After calling this method, - * {@code call} should no longer be used. + * response stream. * *

The returned iterator may throw {@link StatusRuntimeException} on error. * + *

Warning: the iterator can result in leaks if not completely consumed. + * * @return an iterator over the response stream. */ public static Iterator blockingServerStreamingCall( @@ -211,6 +235,82 @@ public static Iterator blockingServerStreamingCall( return result; } + /** + * Initiates a client streaming call over the specified channel. It returns an + * object which can be used in a blocking manner to retrieve responses.. + * + *

The methods {@link BlockingClientCall#hasNext()} and {@link + * BlockingClientCall#cancel(String, Throwable)} can be used for more extensive control. + * + * @return A {@link BlockingClientCall} that has had the request sent and halfClose called + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public static BlockingClientCall blockingV2ServerStreamingCall( + Channel channel, MethodDescriptor method, CallOptions callOptions, ReqT req) { + BlockingClientCall call = + blockingBidiStreamingCall(channel, method, callOptions); + + call.sendSingleRequest(req); + call.halfClose(); + return call; + } + + /** + * Initiates a server streaming call and sends the specified request to the server. It returns an + * object which can be used in a blocking manner to retrieve values from the server. After the + * last value has been read, the next read call will return null. + * + *

Call {@link BlockingClientCall#read()} for + * retrieving values. A {@code null} will be returned after the server has closed the stream. + * + *

The methods {@link BlockingClientCall#hasNext()} and {@link + * BlockingClientCall#cancel(String, Throwable)} can be used for more extensive control. + * + *


Example usage: + *

 {@code  while ((response = call.read()) != null) { ... } } 
+ * or + *
 {@code
+   *   while (call.hasNext()) {
+   *     response = call.read();
+   *     ...
+   *   }
+   * } 
+ * + *

Note that this paradigm is different from the original + * {@link #blockingServerStreamingCall(Channel, MethodDescriptor, CallOptions, Object)} + * which returns an iterator, which would leave the stream open if not completely consumed. + * + * @return A {@link BlockingClientCall} which can be used by the client to write and receive + * messages over the grpc channel. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public static BlockingClientCall blockingClientStreamingCall( + Channel channel, MethodDescriptor method, CallOptions callOptions) { + return blockingBidiStreamingCall(channel, method, callOptions); + } + + /** + * Initiate a bidirectional-streaming {@link ClientCall} and returning a stream object + * ({@link BlockingClientCall}) which can be used by the client to send and receive messages over + * the grpc channel. + * + * @return an object representing the call which can be used to read, write and terminate it. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public static BlockingClientCall blockingBidiStreamingCall( + Channel channel, MethodDescriptor method, CallOptions callOptions) { + ThreadSafeThreadlessExecutor executor = new ThreadSafeThreadlessExecutor(); + ClientCall call = channel.newCall(method, callOptions.withExecutor(executor)); + + BlockingClientCall blockingClientCall = new BlockingClientCall<>(call, executor); + + // Get the call started + call.start(blockingClientCall.getListener(), new Metadata()); + call.request(1); + + return blockingClientCall; + } + /** * Executes a unary call and returns a {@link ListenableFuture} to the response. The * {@code call} should not be already started. 
After calling this method, {@code call} should no @@ -414,7 +514,7 @@ public void disableAutoRequestWithInitial(int request) { public void request(int count) { if (!streamingResponse && count == 1) { // Initially ask for two responses from flow-control so that if a misbehaving server - // sends more than one responses, we can catch it and fail it in the listener. + // sends more than one response, we can catch it and fail it in the listener. call.request(2); } else { call.request(count); @@ -637,7 +737,7 @@ public boolean hasNext() { public T next() { // Eagerly call request(1) so it can be processing the next message while we wait for the // current one, which reduces latency for the next message. With MigratingThreadDeframer and - // if the data has already been recieved, every other message can be delivered instantly. This + // if the data has already been received, every other message can be delivered instantly. This // can be run after hasNext(), but just would be slower. if (!(last instanceof StatusRuntimeException) && last != this) { call.request(1); @@ -726,6 +826,12 @@ public void waitAndDrain() throws InterruptedException { } while ((runnable = poll()) != null); } + private static void throwIfInterrupted() throws InterruptedException { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + } + /** * Called after final call to {@link #waitAndDrain()}, from same thread. 
*/ @@ -745,12 +851,6 @@ private static void runQuietly(Runnable runnable) { } } - private static void throwIfInterrupted() throws InterruptedException { - if (Thread.interrupted()) { - throw new InterruptedException(); - } - } - @Override public void execute(Runnable runnable) { add(runnable); @@ -763,6 +863,128 @@ public void execute(Runnable runnable) { } } + @SuppressWarnings("serial") + static final class ThreadSafeThreadlessExecutor extends ConcurrentLinkedQueue + implements Executor { + private static final Logger log = + Logger.getLogger(ThreadSafeThreadlessExecutor.class.getName()); + + private final Lock waiterLock = new ReentrantLock(); + private final Condition waiterCondition = waiterLock.newCondition(); + + // Non private to avoid synthetic class + ThreadSafeThreadlessExecutor() {} + + /** + * Waits until there is a Runnable, then executes it and all queued Runnables after it. + */ + public void waitAndDrain(Predicate predicate, T testTarget) throws InterruptedException { + try { + waitAndDrainWithTimeout(true, 0, predicate, testTarget); + } catch (TimeoutException e) { + throw new AssertionError(e); // Should never happen + } + } + + /** + * Waits for up to specified nanoseconds until there is a Runnable, then executes it and all + * queued Runnables after it. + * + *

his should always be called in a loop that checks whether the reason we are waiting has + * been satisfied.

T + * + * @param waitForever ignore the rest of the arguments and wait until there is a task to run + * @param end System.nanoTime() to stop waiting if haven't been woken up yet + * @param predicate non-null condition to test for skipping wake or waking up threads + * @param testTarget object to pass to predicate + */ + public void waitAndDrainWithTimeout(boolean waitForever, long end, + @Nonnull Predicate predicate, T testTarget) + throws InterruptedException, TimeoutException { + throwIfInterrupted(); + Runnable runnable; + + while (!predicate.apply(testTarget)) { + waiterLock.lock(); + try { + while ((runnable = poll()) == null) { + if (predicate.apply(testTarget)) { + return; // The condition for which we were waiting is now satisfied + } + + if (waitForever) { + waiterCondition.await(); + } else { + long waitNanos = end - System.nanoTime(); + if (waitNanos <= 0) { + throw new TimeoutException(); // Deadline is expired + } + waiterCondition.awaitNanos(waitNanos); + } + } + } finally { + waiterLock.unlock(); + } + + do { + runQuietly(runnable); + } while ((runnable = poll()) != null); + // Wake everything up now that we've done something and they can check in their outer loop + // if they can continue or need to wait again. + signalAll(); + } + } + + /** Executes all queued Runnables and if there were any wakes up any waiting threads. 
*/ + void drain() { + Runnable runnable; + boolean didWork = false; + + while ((runnable = poll()) != null) { + runQuietly(runnable); + didWork = true; + } + + if (didWork) { + signalAll(); + } + } + + private void signalAll() { + waiterLock.lock(); + try { + waiterCondition.signalAll(); + } finally { + waiterLock.unlock(); + } + } + + private static void runQuietly(Runnable runnable) { + try { + runnable.run(); + } catch (Throwable t) { + log.log(Level.WARNING, "Runnable threw exception", t); + } + } + + private static void throwIfInterrupted() throws InterruptedException { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + } + + @Override + public void execute(Runnable runnable) { + waiterLock.lock(); + try { + add(runnable); + waiterCondition.signalAll(); // If anything is waiting let it wake up and process this task + } finally { + waiterLock.unlock(); + } + } + } + enum StubType { BLOCKING, FUTURE, ASYNC } diff --git a/stub/src/main/java/io/grpc/stub/ServerCalls.java b/stub/src/main/java/io/grpc/stub/ServerCalls.java index 7990a5b34c0..9f0063713cc 100644 --- a/stub/src/main/java/io/grpc/stub/ServerCalls.java +++ b/stub/src/main/java/io/grpc/stub/ServerCalls.java @@ -382,9 +382,10 @@ public void onNext(RespT response) { @Override public void onError(Throwable t) { - Metadata metadata = Status.trailersFromThrowable(t); - if (metadata == null) { - metadata = new Metadata(); + Metadata metadata = new Metadata(); + Metadata trailers = Status.trailersFromThrowable(t); + if (trailers != null) { + metadata.merge(trailers); } call.close(Status.fromThrowable(t), metadata); aborted = true; diff --git a/stub/src/main/java/io/grpc/stub/StreamObservers.java b/stub/src/main/java/io/grpc/stub/StreamObservers.java index 2cc53ea0aa2..a421d3eca2f 100644 --- a/stub/src/main/java/io/grpc/stub/StreamObservers.java +++ b/stub/src/main/java/io/grpc/stub/StreamObservers.java @@ -23,12 +23,21 @@ /** * Utility functions for working with {@link StreamObserver} and 
it's common subclasses like * {@link CallStreamObserver}. - * - * @deprecated Of questionable utility and generally not used. */ -@Deprecated -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/4694") public final class StreamObservers { + // Prevent instantiation + private StreamObservers() { } + + /** + * Utility method to call {@link StreamObserver#onNext(Object)} and + * {@link StreamObserver#onCompleted()} on the specified responseObserver. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10957") + public static void nextAndComplete(StreamObserver responseObserver, T response) { + responseObserver.onNext(response); + responseObserver.onCompleted(); + } + /** * Copy the values of an {@link Iterator} to the target {@link CallStreamObserver} while properly * accounting for outbound flow-control. After calling this method, {@code target} should no @@ -40,7 +49,10 @@ public final class StreamObservers { * * @param source of values expressed as an {@link Iterator}. * @param target {@link CallStreamObserver} which accepts values from the source. + * @deprecated Of questionable utility and generally not used. */ + @Deprecated + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/4694") public static void copyWithFlowControl(final Iterator source, final CallStreamObserver target) { Preconditions.checkNotNull(source, "source"); @@ -80,7 +92,10 @@ public void run() { * * @param source of values expressed as an {@link Iterable}. * @param target {@link CallStreamObserver} which accepts values from the source. + * @deprecated Of questionable utility and generally not used. 
*/ + @Deprecated + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/4694") public static void copyWithFlowControl(final Iterable source, CallStreamObserver target) { Preconditions.checkNotNull(source, "source"); diff --git a/stub/src/test/java/io/grpc/stub/AbstractStubTest.java b/stub/src/test/java/io/grpc/stub/AbstractStubTest.java index 9006b8679e4..352a2fb7fe2 100644 --- a/stub/src/test/java/io/grpc/stub/AbstractStubTest.java +++ b/stub/src/test/java/io/grpc/stub/AbstractStubTest.java @@ -16,12 +16,19 @@ package io.grpc.stub; +import static com.google.common.truth.Truth.assertAbout; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.testing.DeadlineSubject.deadline; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.Deadline; import io.grpc.stub.AbstractStub.StubFactory; import io.grpc.stub.AbstractStubTest.NoopStub; +import java.time.Duration; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -47,8 +54,23 @@ public NoopStub newStub(Channel channel, CallOptions callOptions) { .isNull(); } - class NoopStub extends AbstractStub { + @Test + @IgnoreJRERequirement + public void testDuration() { + NoopStub stub = NoopStub.newStub(new StubFactory() { + @Override + public NoopStub newStub(Channel channel, CallOptions callOptions) { + return create(channel, callOptions); + } + }, channel, CallOptions.DEFAULT); + NoopStub stubInstance = stub.withDeadlineAfter(Duration.ofMinutes(1L)); + Deadline actual = stubInstance.getCallOptions().getDeadline(); + Deadline expected = Deadline.after(1, MINUTES); + assertAbout(deadline()).that(actual).isWithin(10, MILLISECONDS).of(expected); + } + + class NoopStub extends AbstractStub { NoopStub(Channel channel, CallOptions options) { super(channel, options); } diff 
--git a/stub/src/test/java/io/grpc/stub/BlockingClientCallTest.java b/stub/src/test/java/io/grpc/stub/BlockingClientCallTest.java new file mode 100644 index 00000000000..e3a4f90e2c2 --- /dev/null +++ b/stub/src/test/java/io/grpc/stub/BlockingClientCallTest.java @@ -0,0 +1,499 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.stub; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import io.grpc.CallOptions; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.Server; +import io.grpc.ServerServiceDefinition; +import io.grpc.ServiceDescriptor; +import io.grpc.Status; +import io.grpc.Status.Code; +import io.grpc.StatusException; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.ServerCalls.BidiStreamingMethod; +import io.grpc.stub.ServerCallsTest.IntegerMarshaller; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; 
+import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class BlockingClientCallTest { + private static final Logger logger = Logger.getLogger(BlockingClientCallTest.class.getName()); + + public static final int DELAY_MILLIS = 2000; + public static final long DELAY_NANOS = TimeUnit.MILLISECONDS.toNanos(DELAY_MILLIS); + private static final MethodDescriptor BIDI_STREAMING_METHOD = + MethodDescriptor.newBuilder() + .setType(MethodType.BIDI_STREAMING) + .setFullMethodName("some/method") + .setRequestMarshaller(new IntegerMarshaller()) + .setResponseMarshaller(new IntegerMarshaller()) + .build(); + + private Server server; + + private ManagedChannel channel; + + private IntegerTestMethod testMethod; + private BlockingClientCall biDiStream; + + @Before + public void setUp() throws Exception { + testMethod = new IntegerTestMethod(); + + ServerServiceDefinition service = ServerServiceDefinition.builder( + new ServiceDescriptor("some", BIDI_STREAMING_METHOD)) + .addMethod(BIDI_STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall(testMethod)) + .build(); + long tag = System.nanoTime(); + + server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor() + .addService(service).build().start(); + + channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build(); + } + + @After + public void tearDown() { + if (server != null) { + server.shutdownNow(); + } + if (channel != null) { + channel.shutdownNow(); + } + if (biDiStream != null) { + biDiStream.cancel("In teardown", null); + } + } + + @Test + public void sanityTest() throws Exception { + Integer req = 2; + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + // verify activity ready + assertTrue(biDiStream.isEitherReadOrWriteReady()); + assertTrue(biDiStream.isWriteReady()); + + // Have server send a value + testMethod.sendValueToClient(10); + + // Do 
a writeOrRead + biDiStream.write(req, 3, TimeUnit.SECONDS); + assertEquals(Integer.valueOf(10), biDiStream.read(DELAY_MILLIS, TimeUnit.MILLISECONDS)); + + // mark complete + biDiStream.halfClose(); + assertNull(biDiStream.read(2, TimeUnit.SECONDS)); + + // verify activity !ready and !writeable + assertFalse(biDiStream.isEitherReadOrWriteReady()); + assertFalse(biDiStream.isWriteReady()); + + assertEquals(Code.OK, biDiStream.getClosedStatus().getCode()); + } + + @Test + public void testReadSuccess_withoutBlocking() throws Exception { + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + // Have server push a value + testMethod.sendValueToClient(11); + + long start = System.nanoTime(); + Integer value = biDiStream.read(100, TimeUnit.SECONDS); + assertNotNull(value); + long timeTaken = System.nanoTime() - start; + assertThat(timeTaken).isLessThan(TimeUnit.MILLISECONDS.toNanos(100)); + } + + @Test + public void testReadSuccess_withBlocking() throws Exception { + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + try { + biDiStream.read(1, TimeUnit.SECONDS); + fail("Expected timeout"); + } catch (TimeoutException t) { + // ignore + } + + long start = System.nanoTime(); + delayedAddValue(DELAY_MILLIS, 12); + assertNotNull(biDiStream.read(DELAY_MILLIS * 2, TimeUnit.MILLISECONDS)); + long timeTaken = System.nanoTime() - start; + assertThat(timeTaken).isGreaterThan(DELAY_NANOS); + assertThat(timeTaken).isLessThan(DELAY_NANOS * 2); + + start = System.nanoTime(); + Integer[] values = {13, 14, 15, 16}; + delayedAddValue(DELAY_MILLIS, values); + for (Integer value : values) { + Integer readValue = biDiStream.read(DELAY_MILLIS * 2, TimeUnit.MILLISECONDS); + assertEquals(value, readValue); + } + timeTaken = System.nanoTime() - start; + assertThat(timeTaken).isLessThan(DELAY_NANOS * 2); + assertThat(timeTaken).isAtLeast(DELAY_NANOS); + + start = System.nanoTime(); 
+ delayedVoidMethod(100, testMethod::halfClose); + assertNull(biDiStream.read(DELAY_MILLIS * 2, TimeUnit.MILLISECONDS)); + timeTaken = System.nanoTime() - start; + assertThat(timeTaken).isLessThan(DELAY_NANOS); + } + + @Test + public void testCancel() throws Exception { + testMethod.disableAutoRequest(); + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + // read terminated + long start = System.currentTimeMillis(); + delayedCancel(biDiStream, "cancel read"); + try { + assertNull(biDiStream.read(2 * DELAY_MILLIS, TimeUnit.MILLISECONDS)); + fail("No exception thrown by read after cancel"); + } catch (StatusException e) { + assertEquals(Status.CANCELLED.getCode(), e.getStatus().getCode()); + assertThat(System.currentTimeMillis() - start).isLessThan(2 * DELAY_MILLIS); + } + + // after cancel tests + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + biDiStream.cancel("cancel write", new RuntimeException("Test requested close")); + + // Write after cancel should throw an exception + try { + start = System.currentTimeMillis(); + biDiStream.write(30); + fail("No exception doing write after cancel"); + } catch (IllegalStateException e) { + assertThat(System.currentTimeMillis() - start).isLessThan(200); + assertThat(e.getMessage()).contains("cancel"); + } + + // new read after cancel immediately throws an exception + try { + start = System.currentTimeMillis(); + assertNull(biDiStream.read(2, TimeUnit.SECONDS)); + } catch (StatusException e) { + assertEquals(Status.CANCELLED.getCode(), e.getStatus().getCode()); + assertThat(System.currentTimeMillis() - start).isLessThan(200); + } + + } + + @Test + public void testIsActivityReady() throws Exception { + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + // write only ready + assertTrue(biDiStream.isEitherReadOrWriteReady()); + 
assertTrue(biDiStream.isWriteReady()); + assertFalse(biDiStream.isReadReady()); + + // both ready + testMethod.sendValueToClient(40); + assertTrue(biDiStream.isEitherReadOrWriteReady()); + assertTrue(biDiStream.isReadReady()); + assertTrue(biDiStream.isWriteReady()); + + // read only ready + biDiStream.halfClose(); + assertTrue(biDiStream.isEitherReadOrWriteReady()); + assertTrue(biDiStream.isReadReady()); + assertFalse(biDiStream.isWriteReady()); + + // Neither ready + assertNotNull(biDiStream.read(1, TimeUnit.MILLISECONDS)); + assertFalse(biDiStream.isEitherReadOrWriteReady()); + assertFalse(biDiStream.isReadReady()); + assertFalse(biDiStream.isWriteReady()); + } + + @Test + public void testWriteSuccess_withBlocking() throws Exception { + testMethod.disableAutoRequest(); + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + assertFalse(biDiStream.isWriteReady()); + delayedWriteEnable(500); + assertTrue(biDiStream.write(40)); + + delayedWriteEnable(500); + assertTrue(biDiStream.write(41, 0, TimeUnit.NANOSECONDS)); + } + + + @Test + public void testReadNonblocking_whenWriteBlocked() throws Exception { + testMethod.disableAutoRequest(); + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + // One value waiting + testMethod.sendValueToClient(50); + long start = System.currentTimeMillis(); + assertEquals(Integer.valueOf(50), biDiStream.read()); + assertThat(System.currentTimeMillis() - start).isLessThan(DELAY_MILLIS); + + // Two values waiting + start = System.currentTimeMillis(); + testMethod.sendValuesToClient(51, 52); + assertEquals(Integer.valueOf(51), biDiStream.read()); + assertEquals(Integer.valueOf(52), biDiStream.read()); + assertThat(System.currentTimeMillis() - start).isLessThan(DELAY_MILLIS); + } + + @Test + public void testReadsAndWritesInterleaved_withBlocking() throws Exception { + biDiStream = 
ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + Integer[] valuesOut = {1001, 10022, 1003}; + Integer[] valuesIn = new Integer[valuesOut.length]; + delayedAddValue(300, valuesOut); + int iteration = 0; + for (int i = 0; i < valuesOut.length && iteration++ < (20 + valuesOut.length); ) { + try { + if ((valuesIn[i] = biDiStream.read(50, TimeUnit.MILLISECONDS)) != null) { + i++; + } + } catch (TimeoutException e) { + logger.info("Read timed out for " + i); + } + } + assertArrayEquals(valuesOut, valuesIn); + } + + @Test + public void testReadsAndWritesInterleaved_BlockingWrites() throws Exception { + testMethod.disableAutoRequest(); + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + testMethod.sendValuesToClient(10, 11, 12); + delayedWriteEnable(500); + long start = System.currentTimeMillis(); + boolean done = false; + int count = 0; + while (!done) { + count++; + if (!biDiStream.isWriteReady() && biDiStream.isReadReady()) { + biDiStream.read(100, TimeUnit.MILLISECONDS); + } else { + done = biDiStream.write(100, 1, TimeUnit.SECONDS); + } + } + assertEquals(4, count); + assertThat(System.currentTimeMillis() - start).isLessThan(700); + + testMethod.sendValuesToClient(20, 21, 22); + delayedWriteEnable(100); + while (!biDiStream.isWriteReady()) { + Thread.sleep(20); + } + + assertTrue(biDiStream.write(1000, 2 * DELAY_MILLIS, TimeUnit.MILLISECONDS)); + + assertEquals(Integer.valueOf(20), biDiStream.read(200, TimeUnit.MILLISECONDS)); + assertEquals(Integer.valueOf(21), biDiStream.read(200, TimeUnit.MILLISECONDS)); + assertEquals(Integer.valueOf(22), biDiStream.read(200, TimeUnit.MILLISECONDS)); + try { + Integer value = biDiStream.read(200, TimeUnit.MILLISECONDS); + fail("Unexpected read success instead of timeout. 
Value was: " + value); + } catch (TimeoutException ignore) { + // ignore since expected + } + } + + @Test + public void testWriteAfterCloseThrows() throws Exception { + testMethod.disableAutoRequest(); + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + // verify new writes throw an illegalStateException + biDiStream.halfClose(); + try { + assertFalse(biDiStream.write(2)); + fail("write did not throw an exception when called after halfClose"); + } catch (IllegalStateException e) { + assertThat(e.getMessage()).containsMatch("after.*halfClose.*cancel"); + } + } + + @Test + public void testClose_withException() throws Exception { + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + String descr = "too many small numbers"; + testMethod.sendError( + Status.FAILED_PRECONDITION.withDescription(descr).asRuntimeException()); + Status closedStatus = biDiStream.getClosedStatus(); + assertEquals(Code.FAILED_PRECONDITION, closedStatus.getCode()); + assertEquals(descr, closedStatus.getDescription()); + try { + assertFalse(biDiStream.write(1)); + } catch (StatusException e) { + assertThat(e.getMessage()).startsWith("FAILED_PRECONDITION"); + } + } + + private void delayedAddValue(int delayMillis, Integer... 
values) { + new Thread("delayedAddValue " + values.length) { + @Override + public void run() { + try { + Thread.sleep(delayMillis); + for (Integer cur : values) { + testMethod.sendValueToClient(cur); + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }.start(); + } + + public interface Thunk { void apply(); } // supports passing void method w/out args + + private void delayedVoidMethod(int delayMillis, Thunk method) { + new Thread("delayedHalfClose") { + @Override + public void run() { + try { + Thread.sleep(delayMillis); + method.apply(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }.start(); + } + + private void delayedWriteEnable(int delayMillis) { + delayedVoidMethod(delayMillis, testMethod::readValueFromClient); + } + + private void delayedCancel(BlockingClientCall biDiStream, String message) { + new Thread("delayedCancel") { + @Override + public void run() { + try { + Thread.sleep(BlockingClientCallTest.DELAY_MILLIS); + biDiStream.cancel(message, new RuntimeException("Test requested close")); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }.start(); + } + + private static class IntegerTestMethod implements BidiStreamingMethod { + boolean autoRequest = true; + + void disableAutoRequest() { + assertNull("Can't disable auto request after invoke has been called", serverCallObserver); + autoRequest = false; + } + + ServerCallStreamObserver serverCallObserver; + + @Override + public StreamObserver invoke(StreamObserver responseObserver) { + serverCallObserver = (ServerCallStreamObserver) responseObserver; + if (!autoRequest) { + serverCallObserver.disableAutoRequest(); + } + + return new StreamObserver() { + @Override + public void onNext(Integer value) { + if (!autoRequest) { + serverCallObserver.request(1); + } + + // For testing ReqResp actions + if (value > 1000) { + serverCallObserver.onNext(value); + } + } + + @Override + public void onError(Throwable t) { + 
// no-op + } + + @Override + public void onCompleted() { + serverCallObserver.onCompleted(); + } + }; + } + + void readValueFromClient() { + serverCallObserver.request(1); + } + + void sendValueToClient(int value) { + serverCallObserver.onNext(value); + } + + private void sendValuesToClient(int ...values) { + for (int cur : values) { + sendValueToClient(cur); + } + } + + void halfClose() { + serverCallObserver.onCompleted(); + } + + void sendError(Throwable t) { + serverCallObserver.onError(t); + } + } + +} diff --git a/stub/src/test/java/io/grpc/stub/ClientCallsTest.java b/stub/src/test/java/io/grpc/stub/ClientCallsTest.java index f3d101b862a..b711b2a23b5 100644 --- a/stub/src/test/java/io/grpc/stub/ClientCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ClientCallsTest.java @@ -971,8 +971,8 @@ public ClientCall interceptCall( } @Override public void halfClose() { - Thread.currentThread().interrupt(); super.halfClose(); + Thread.currentThread().interrupt(); } }; } diff --git a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java index 1e51ac10110..6f458facc5e 100644 --- a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java @@ -555,6 +555,35 @@ public void invoke(Integer req, StreamObserver responseObserver) { listener.onHalfClose(); } + @Test + public void clientSendsOne_serverOnErrorWithTrailers_serverStreaming() { + Metadata trailers = new Metadata(); + Metadata.Key key = Metadata.Key.of("trailers-test-key1", + Metadata.ASCII_STRING_MARSHALLER); + trailers.put(key, "trailers-test-value1"); + + ServerCallRecorder serverCall = new ServerCallRecorder(SERVER_STREAMING_METHOD); + ServerCallHandler callHandler = ServerCalls.asyncServerStreamingCall( + new ServerCalls.ServerStreamingMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + responseObserver.onError( + Status.fromCode(Status.Code.INTERNAL) + 
.asRuntimeException(trailers) + ); + } + }); + ServerCall.Listener listener = callHandler.startCall(serverCall, new Metadata()); + serverCall.isReady = true; + serverCall.isCancelled = false; + listener.onReady(); + listener.onMessage(1); + listener.onHalfClose(); + // verify trailers key is set + assertTrue(serverCall.trailers.containsKey(key)); + assertTrue(serverCall.status.equals(Status.INTERNAL)); + } + @Test public void inprocessTransportManualFlow() throws Exception { final Semaphore semaphore = new Semaphore(1); @@ -652,6 +681,7 @@ private static class ServerCallRecorder extends ServerCall { private boolean isCancelled; private boolean isReady; private int onReadyThreshold; + private Metadata trailers; public ServerCallRecorder(MethodDescriptor methodDescriptor) { this.methodDescriptor = methodDescriptor; @@ -674,6 +704,7 @@ public void sendMessage(Integer message) { @Override public void close(Status status, Metadata trailers) { this.status = status; + this.trailers = trailers; } @Override diff --git a/stub/src/test/java/io/grpc/stub/StreamObserversTest.java b/stub/src/test/java/io/grpc/stub/StreamObserversTest.java new file mode 100644 index 00000000000..237dd2e1434 --- /dev/null +++ b/stub/src/test/java/io/grpc/stub/StreamObserversTest.java @@ -0,0 +1,38 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.stub; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.InOrder; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class StreamObserversTest { + + @Test + public void nextAndComplete() { + @SuppressWarnings("unchecked") + StreamObserver observer = Mockito.mock(StreamObserver.class); + InOrder inOrder = Mockito.inOrder(observer); + StreamObservers.nextAndComplete(observer, "TEST"); + inOrder.verify(observer).onNext("TEST"); + inOrder.verify(observer).onCompleted(); + inOrder.verifyNoMoreInteractions(); + } +} diff --git a/testing-proto/BUILD.bazel b/testing-proto/BUILD.bazel new file mode 100644 index 00000000000..aa0fc9ee20b --- /dev/null +++ b/testing-proto/BUILD.bazel @@ -0,0 +1,22 @@ +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") +load("//:java_grpc_library.bzl", "java_grpc_library") + +proto_library( + name = "simpleservice_proto", + srcs = ["src/main/proto/io/grpc/testing/protobuf/simpleservice.proto"], + strip_import_prefix = "src/main/proto/", +) + +java_proto_library( + name = "simpleservice_java_proto", + visibility = ["//xds:__pkg__"], + deps = [":simpleservice_proto"], +) + +java_grpc_library( + name = "simpleservice_java_grpc", + srcs = [":simpleservice_proto"], + visibility = ["//xds:__pkg__"], + deps = [":simpleservice_java_proto"], +) diff --git a/testing-proto/build.gradle b/testing-proto/build.gradle index e6afce468f0..ee602bc5135 100644 --- a/testing-proto/build.gradle +++ b/testing-proto/build.gradle @@ -17,10 +17,12 @@ tasks.named("jar").configure { dependencies { api project(':grpc-protobuf'), project(':grpc-stub') - compileOnly libraries.javax.annotation testImplementation libraries.truth - testRuntimeOnly libraries.javax.annotation - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = 
"signature" + } + } } configureProtoCompilation() diff --git a/testing-proto/src/generated/main/grpc/io/grpc/testing/protobuf/SimpleServiceGrpc.java b/testing-proto/src/generated/main/grpc/io/grpc/testing/protobuf/SimpleServiceGrpc.java index 8c58f2c5a2c..e242fd0f513 100644 --- a/testing-proto/src/generated/main/grpc/io/grpc/testing/protobuf/SimpleServiceGrpc.java +++ b/testing-proto/src/generated/main/grpc/io/grpc/testing/protobuf/SimpleServiceGrpc.java @@ -7,9 +7,6 @@ * A simple service for test. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: io/grpc/testing/protobuf/simpleservice.proto") @io.grpc.stub.annotations.GrpcGenerated public final class SimpleServiceGrpc { @@ -156,6 +153,21 @@ public SimpleServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions ca return SimpleServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static SimpleServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public SimpleServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new SimpleServiceBlockingV2Stub(channel, callOptions); + } + }; + return SimpleServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -318,6 +330,72 @@ public io.grpc.stub.StreamObserver bidiS * A simple service for test. 
* */ + public static final class SimpleServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private SimpleServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected SimpleServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new SimpleServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Simple unary RPC.
+     * 
+ */ + public io.grpc.testing.protobuf.SimpleResponse unaryRpc(io.grpc.testing.protobuf.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryRpcMethod(), getCallOptions(), request); + } + + /** + *
+     * Simple client-to-server streaming RPC.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + clientStreamingRpc() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getClientStreamingRpcMethod(), getCallOptions()); + } + + /** + *
+     * Simple server-to-client streaming RPC.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + serverStreamingRpc(io.grpc.testing.protobuf.SimpleRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getServerStreamingRpcMethod(), getCallOptions(), request); + } + + /** + *
+     * Simple bidirectional streaming RPC.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + bidiStreamingRpc() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getBidiStreamingRpcMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service SimpleService. + *
+   * A simple service for test.
+   * 
+ */ public static final class SimpleServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private SimpleServiceBlockingStub( diff --git a/testing/BUILD.bazel b/testing/BUILD.bazel index 78f9b840754..d280ab97ee1 100644 --- a/testing/BUILD.bazel +++ b/testing/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( diff --git a/testing/build.gradle b/testing/build.gradle index cc83b7ad620..b92e39279c6 100644 --- a/testing/build.gradle +++ b/testing/build.gradle @@ -24,8 +24,16 @@ dependencies { testImplementation project(':grpc-testing-proto'), testFixtures(project(':grpc-core')) - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("javadoc").configure { exclude 'io/grpc/internal/**' } diff --git a/testing/src/main/java/io/grpc/internal/testing/FakeNameResolverProvider.java b/testing/src/main/java/io/grpc/internal/testing/FakeNameResolverProvider.java index 4664dbcc436..c77f7f8945a 100644 --- a/testing/src/main/java/io/grpc/internal/testing/FakeNameResolverProvider.java +++ b/testing/src/main/java/io/grpc/internal/testing/FakeNameResolverProvider.java @@ -21,6 +21,7 @@ import io.grpc.NameResolver; import io.grpc.NameResolverProvider; import io.grpc.Status; +import io.grpc.StatusOr; import java.net.SocketAddress; import java.net.URI; import java.util.Collection; @@ -81,9 +82,10 @@ public void start(Listener2 listener) { if (shutdown) { listener.onError(Status.FAILED_PRECONDITION.withDescription("Resolver is shutdown")); } else { - listener.onResult( + listener.onResult2( ResolutionResult.newBuilder() - .setAddresses(ImmutableList.of(new EquivalentAddressGroup(address))) + .setAddressesOrError( + StatusOr.fromValue(ImmutableList.of(new 
EquivalentAddressGroup(address)))) .build()); } } diff --git a/testing/src/main/resources/certs/README b/testing/src/main/resources/certs/README index 1fa6b733950..13e375784c5 100644 --- a/testing/src/main/resources/certs/README +++ b/testing/src/main/resources/certs/README @@ -67,6 +67,35 @@ ecdsa.key is used to test keys with algorithm other than RSA: $ openssl ecparam -name secp256k1 -genkey -noout -out ecdsa.pem $ openssl pkcs8 -topk8 -in ecdsa.pem -out ecdsa.key -nocrypt +SPIFFE test credentials: +======================= + +The SPIFFE related extensions are listed in spiffe-openssl.cnf config. Both +client_spiffe.pem and server1_spiffe.pem are generated in the same way with +original client.pem and server1.pem but with using that config. Here are the +exact commands (we pass "-subj" as argument in this case): +---------------------- +$ openssl req -new -key client.key -out spiffe-cert.csr \ + -subj /C=US/ST=CA/L=SVL/O=gRPC/CN=testclient/ \ + -config spiffe-openssl.cnf -reqexts spiffe_client_e2e +$ openssl x509 -req -CA ca.pem -CAkey ca.key -CAcreateserial \ + -in spiffe-cert.csr -out client_spiffe.pem -extensions spiffe_client_e2e \ + -extfile spiffe-openssl.cnf -days 3650 -sha256 +$ openssl req -new -key server1.key -out spiffe-cert.csr \ + -subj /C=US/ST=CA/L=SVL/O=gRPC/CN=*.test.google.com/ \ + -config spiffe-openssl.cnf -reqexts spiffe_server_e2e +$ openssl x509 -req -CA ca.pem -CAkey ca.key -CAcreateserial \ + -in spiffe-cert.csr -out server1_spiffe.pem -extensions spiffe_server_e2e \ + -extfile spiffe-openssl.cnf -days 3650 -sha256 + +Additionally, SPIFFE trust bundle map files spiffebundle.json and \ +spiffebundle1.json are manually created for end to end testing. The \ +spiffebundle.json contains "example.com" trust domain (only this entry is used \ +in e2e tests) matching URI SAN of server1_spiffe.pem, and the CA certificate \ +there is ca.pem. 
The spiffebundle1.json file contains "foo.bar.com" trust \ +domain (only this entry is used in e2e tests) matching URI SAN of \ +client_spiffe.pem, and the CA certificate there is also ca.pem. + Clean up: --------- $ rm *.rsa diff --git a/testing/src/main/resources/certs/client_spiffe.pem b/testing/src/main/resources/certs/client_spiffe.pem new file mode 100644 index 00000000000..c70981a4030 --- /dev/null +++ b/testing/src/main/resources/certs/client_spiffe.pem @@ -0,0 +1,25 @@ +-----BEGIN CERTIFICATE----- +MIIEMjCCAxqgAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShUwDQYJKoZIhvcNAQEL +BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 +MTAyNDE2NDAzN1oXDTM0MTAyMjE2NDAzN1owaDELMAkGA1UEBhMCQVUxEzARBgNV +BAgMClNvbWUtU3RhdGUxDDAKBgNVBAcMA1NWTDEhMB8GA1UECgwYSW50ZXJuZXQg +V2lkZ2l0cyBQdHkgTHRkMRMwEQYDVQQDDAp0ZXN0Y2xpZW50MIIBIjANBgkqhkiG +9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsqmEafg11ae9jRW0B/IXYU2S8nGVzpSYZjLK +yZq459qe6SP/Jk2f9BQvkhlgRmVfhC4h65gl+c32iC6/SLsOxoa91c6Hn4vK+tqy +7qVTzYv6naso1pNnRAhwvWd/gINysyk8nq11oynL8ilZjNGcRNEV4Q1v0aEG6mbF +NhioNQdq4VFPCjdIFZip9KyRzsc0VUmHmC2KeWJ+yq7TyXCsqPWlbhK+3RgDc6ch +epYP52AVnPvUhsJKC3RbyrwAWCTMq2zYR1EH79H82mdD/OnX0xDaw8cwC68xp6nM +dyk68CY5Gf2kq9bcg9P7V77pERYj8VgSYYx0O9BqkxUGNfUW4QIDAQABo4HlMIHi +MEQGA1UdEQQ9MDuGOXNwaWZmZTovL2Zvby5iYXIuY29tLzllZWJjY2QyLTEyYmYt +NDBhNi1iMjYyLTY1ZmUwNDg3ZDQ1MzAdBgNVHQ4EFgQU28U8sUTGNEDyeCrvJDJd +AALabSMwewYDVR0jBHQwcqFapFgwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNv +bWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0G +A1UEAwwGdGVzdGNhghRas/RW8dzL4s/pS5g22Iv2AGEPmjANBgkqhkiG9w0BAQsF +AAOCAQEAE3LLE8GR283q/aE646SgAfltqpESP38NmYdJMdZgWRxbOqdWabYDfibt +9r8j+IRvVuuTWuH2eNS5wXJtS1BZ+z24wTLa+a2KjOV12gChP+3N7jhqId4eolSL +1fjscPY6luZP4Pm3D73lBvIoBvXpDGyrxleiUCEEkKXmTOA8doFvbrcbwm+yUJOP +VKUKvAzTNztb0BGDzKKU4E2yK5PSyv2n5m2NpzxYYfHoGeVcxvj7nCnSfoX/EWHb +d8ztJYDg9X0iNcfQXt7PZ+j6VcxfDpGCDxe2rFQoYvlWjhr3xOi/1e5A1zx1Ly07 +m9MB4hntu4e2656ZDWbgOHLpO0q1iQ==
+-----END CERTIFICATE----- diff --git a/testing/src/main/resources/certs/server1_spiffe.pem b/testing/src/main/resources/certs/server1_spiffe.pem new file mode 100644 index 00000000000..76cb41d6922 --- /dev/null +++ b/testing/src/main/resources/certs/server1_spiffe.pem @@ -0,0 +1,26 @@ +-----BEGIN CERTIFICATE----- +MIIEZDCCA0ygAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShMwDQYJKoZIhvcNAQEL +BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 +MTAyMTAyMTQxNVoXDTM0MTAxOTAyMTQxNVowZTELMAkGA1UEBhMCVVMxETAPBgNV +BAgMCElsbGlub2lzMRAwDgYDVQQHDAdDaGljYWdvMRUwEwYDVQQKDAxFeGFtcGxl +LCBDby4xGjAYBgNVBAMMESoudGVzdC5nb29nbGUuY29tMIIBIjANBgkqhkiG9w0B +AQEFAAOCAQ8AMIIBCgKCAQEA5xOONxJJ8b8Qauvob5/7dPYZfIcd+uhAWL2ZlTPz +Qvu4oF0QI4iYgP5iGgry9zEtCM+YQS8UhiAlPlqa6ANxgiBSEyMHH/xE8lo/+caY +GeACqy640Jpl/JocFGo3xd1L8DCawjlaj6eu7T7T/tpAV2qq13b5710eNRbCAfFe +8yALiGQemx0IYhlZXNbIGWLBNhBhvVjJh7UvOqpADk4xtl8o5j0xgMIRg6WJGK6c +6ffSIg4eP1XmovNYZ9LLEJG68tF0Q/yIN43B4dt1oq4jzSdCbG4F1EiykT2TmwPV +YDi8tml6DfOCDGnit8svnMEmBv/fcPd31GSbXjF8M+KGGQIDAQABo4IBGTCCARUw +dwYDVR0RBHAwboIQKi50ZXN0Lmdvb2dsZS5mcoIYd2F0ZXJ6b29pLnRlc3QuZ29v +Z2xlLmJlghIqLnRlc3QueW91dHViZS5jb22HBMCoAQOGJnNwaWZmZTovL2V4YW1w +bGUuY29tL3dvcmtsb2FkLzllZWJjY2QyMB0GA1UdDgQWBBRvRpAYHQYP6dFPf5V7 +/MyCftnNjTB7BgNVHSMEdDByoVqkWDBWMQswCQYDVQQGEwJBVTETMBEGA1UECAwK +U29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMQ8w +DQYDVQQDDAZ0ZXN0Y2GCFFqz9Fbx3Mviz+lLmDbYi/YAYQ+aMA0GCSqGSIb3DQEB +CwUAA4IBAQBJ1bnbBHa1n15vvhpGIzokuiJ+9q/zim63UuVDnkhrQM2N+RQbStGT +Tis2tNse1bh460dJFm6ArgHWogzx6fQZzgaDeCOAXvrAe4jM9IHr9K7lkq/33CZS +BDV+jCmm2sRsqSMkKUcX6JhyqWGFHuTDAKJzsEV2MlcswleKlGHDkeelAaxlLzpz +RHOSQd0N9xAs18lzx95SQEx90PtrBOmvIDDiI5o5z9Oz12Iy1toiksFl4jmknkDD +5VF3AyCRgN8NPW0uNC8D2vo4L+tgj9U6NPlmMOrjRsEH257LJ1wopAGr+yezkIId +QQodGSVm5cOuw/K7Ma4nBDjVJkjcdY3t +-----END CERTIFICATE----- diff --git a/testing/src/main/resources/certs/spiffe-openssl.cnf 
b/testing/src/main/resources/certs/spiffe-openssl.cnf new file mode 100644 index 00000000000..f03af40a782 --- /dev/null +++ b/testing/src/main/resources/certs/spiffe-openssl.cnf @@ -0,0 +1,28 @@ +[spiffe_client] +subjectAltName = @alt_names + +[spiffe_client_multi] +subjectAltName = @alt_names_multi + +[spiffe_server_e2e] +subjectAltName = @alt_names_server_e2e + +[spiffe_client_e2e] +subjectAltName = @alt_names_client_e2e + +[alt_names] +URI = spiffe://foo.bar.com/client/workload/1 + +[alt_names_multi] +URI.1 = spiffe://foo.bar.com/client/workload/1 +URI.2 = spiffe://foo.bar.com/client/workload/2 + +[alt_names_server_e2e] +DNS.1 = *.test.google.fr +DNS.2 = waterzooi.test.google.be +DNS.3 = *.test.youtube.com +IP.1 = "192.168.1.3" +URI = spiffe://example.com/workload/9eebccd2 + +[alt_names_client_e2e] +URI = spiffe://foo.bar.com/9eebccd2-12bf-40a6-b262-65fe0487d453 \ No newline at end of file diff --git a/testing/src/main/resources/certs/spiffe_cert.pem b/testing/src/main/resources/certs/spiffe_cert.pem new file mode 100644 index 00000000000..bc070042f69 --- /dev/null +++ b/testing/src/main/resources/certs/spiffe_cert.pem @@ -0,0 +1,33 @@ +-----BEGIN CERTIFICATE----- +MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL +BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL +BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy +NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM +MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu +dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ +LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z +G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO +a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z +JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV +m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 +7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc 
+msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc +DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN +zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs +vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI +sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH +HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP +BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t +L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ +aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 +4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 +IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ +PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV ++j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D +vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq +yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV +z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx +x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U +0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX +GA91fn0891b5eEW8BJHXX0jri0aN8g== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/testing/src/main/resources/certs/spiffe_multi_uri_san_cert.pem b/testing/src/main/resources/certs/spiffe_multi_uri_san_cert.pem new file mode 100644 index 00000000000..eb5c879abf8 --- /dev/null +++ b/testing/src/main/resources/certs/spiffe_multi_uri_san_cert.pem @@ -0,0 +1,25 @@ +-----BEGIN CERTIFICATE----- +MIIELTCCAxWgAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShEwDQYJKoZIhvcNAQEL +BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 +MDkxNzE2MTk0NFoXDTM0MDkxNTE2MTk0NFowTjELMAkGA1UEBhMCVVMxCzAJBgNV +BAgMAkNBMQwwCgYDVQQHDANTVkwxDTALBgNVBAoMBGdSUEMxFTATBgNVBAMMDHRl +c3QtY2xpZW50MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOcTjjcS 
+SfG/EGrr6G+f+3T2GXyHHfroQFi9mZUz80L7uKBdECOImID+YhoK8vcxLQjPmEEv +FIYgJT5amugDcYIgUhMjBx/8RPJaP/nGmBngAqsuuNCaZfyaHBRqN8XdS/AwmsI5 +Wo+nru0+0/7aQFdqqtd2+e9dHjUWwgHxXvMgC4hkHpsdCGIZWVzWyBliwTYQYb1Y +yYe1LzqqQA5OMbZfKOY9MYDCEYOliRiunOn30iIOHj9V5qLzWGfSyxCRuvLRdEP8 +iDeNweHbdaKuI80nQmxuBdRIspE9k5sD1WA4vLZpeg3zggxp4rfLL5zBJgb/33D3 +d9Rkm14xfDPihhkCAwEAAaOB+jCB9zBZBgNVHREEUjBQhiZzcGlmZmU6Ly9mb28u +YmFyLmNvbS9jbGllbnQvd29ya2xvYWQvMYYmc3BpZmZlOi8vZm9vLmJhci5jb20v +Y2xpZW50L3dvcmtsb2FkLzIwHQYDVR0OBBYEFG9GkBgdBg/p0U9/lXv8zIJ+2c2N +MHsGA1UdIwR0MHKhWqRYMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 +YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMM +BnRlc3RjYYIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQELBQADggEB +AJ4Cbxv+02SpUgkEu4hP/1+8DtSBXUxNxI0VG4e3Ap2+Rhjm3YiFeS/UeaZhNrrw +UEjkSTPFODyXR7wI7UO9OO1StyD6CMkp3SEvevU5JsZtGL6mTiTLTi3Qkywa91Bt +GlyZdVMghA1bBJLBMwiD5VT5noqoJBD7hDy6v9yNmt1Sw2iYBJPqI3Gnf5bMjR3s +UICaxmFyqaMCZsPkfJh0DmZpInGJys3m4QqGz6ZE2DWgcSr1r/ML7/5bSPjjr8j4 +WFFSqFR3dMu8CbGnfZTCTXa4GTX/rARXbAO67Z/oJbJBK7VKayskL+PzKuohb9ox +jGL772hQMbwtFCOFXu5VP0s= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/testing/src/main/resources/certs/spiffebundle.json b/testing/src/main/resources/certs/spiffebundle.json new file mode 100644 index 00000000000..5bc8fcfb432 --- /dev/null +++ b/testing/src/main/resources/certs/spiffebundle.json @@ -0,0 +1,101 @@ +{ + "trust_domains": { + "example.com": { + "spiffe_sequence": 12035488, + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIDWjCCAkKgAwIBAgIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTIw + MDMxNzE4NTk1MVoXDTMwMDMxNTE4NTk1MVowVjELMAkGA1UEBhMCQVUxEzARBgNV + BAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0 + ZDEPMA0GA1UEAwwGdGVzdGNhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC + 
AQEAsGL0oXflF0LzoM+Bh+qUU9yhqzw2w8OOX5mu/iNCyUOBrqaHi7mGHx73GD01 + diNzCzvlcQqdNIH6NQSL7DTpBjca66jYT9u73vZe2MDrr1nVbuLvfu9850cdxiUO + Inv5xf8+sTHG0C+a+VAvMhsLiRjsq+lXKRJyk5zkbbsETybqpxoJ+K7CoSy3yc/k + QIY3TipwEtwkKP4hzyo6KiGd/DPexie4nBUInN3bS1BUeNZ5zeaIC2eg3bkeeW7c + qT55b+Yen6CxY0TEkzBK6AKt/WUialKMgT0wbTxRZO7kUCH3Sq6e/wXeFdJ+HvdV + LPlAg5TnMaNpRdQih/8nRFpsdwIDAQABoyAwHjAMBgNVHRMEBTADAQH/MA4GA1Ud + DwEB/wQEAwICBDANBgkqhkiG9w0BAQsFAAOCAQEAkTrKZjBrJXHps/HrjNCFPb5a + THuGPCSsepe1wkKdSp1h4HGRpLoCgcLysCJ5hZhRpHkRihhef+rFHEe60UePQO3S + CVTtdJB4CYWpcNyXOdqefrbJW5QNljxgi6Fhvs7JJkBqdXIkWXtFk2eRgOIP2Eo9 + /OHQHlYnwZFrk6sp4wPyR+A95S0toZBcyDVz7u+hOW0pGK3wviOe9lvRgj/H3Pwt + bewb0l+MhRig0/DVHamyVxrDRbqInU1/GTNCwcZkXKYFWSf92U+kIcTth24Q1gcw + eZiLl5FfrWokUNytFElXob0V0a5/kbhiLc3yWmvWqHTpqCALbVyF+rKJo2f5Kw=="], + "n": "", + "e": "AQAB" + } + ] + }, + "test.example.com": { + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL + BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL + BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy + NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM + MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu + dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ + LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z + G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO + a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z + JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV + m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 + 7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc + msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc + DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN + zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs + 
vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI + sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH + HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP + BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t + L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ + aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 + 4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 + IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ + PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV + +j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D + vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq + yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV + z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx + x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U + 0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX + GA91fn0891b5eEW8BJHXX0jri0aN8g=="], + "n": "", + "e": "AQAB" + }, + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIELTCCAxWgAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShEwDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 + MDkxNzE2MTk0NFoXDTM0MDkxNTE2MTk0NFowTjELMAkGA1UEBhMCVVMxCzAJBgNV + BAgMAkNBMQwwCgYDVQQHDANTVkwxDTALBgNVBAoMBGdSUEMxFTATBgNVBAMMDHRl + c3QtY2xpZW50MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOcTjjcS + SfG/EGrr6G+f+3T2GXyHHfroQFi9mZUz80L7uKBdECOImID+YhoK8vcxLQjPmEEv + FIYgJT5amugDcYIgUhMjBx/8RPJaP/nGmBngAqsuuNCaZfyaHBRqN8XdS/AwmsI5 + Wo+nru0+0/7aQFdqqtd2+e9dHjUWwgHxXvMgC4hkHpsdCGIZWVzWyBliwTYQYb1Y + yYe1LzqqQA5OMbZfKOY9MYDCEYOliRiunOn30iIOHj9V5qLzWGfSyxCRuvLRdEP8 + iDeNweHbdaKuI80nQmxuBdRIspE9k5sD1WA4vLZpeg3zggxp4rfLL5zBJgb/33D3 + d9Rkm14xfDPihhkCAwEAAaOB+jCB9zBZBgNVHREEUjBQhiZzcGlmZmU6Ly9mb28u + YmFyLmNvbS9jbGllbnQvd29ya2xvYWQvMYYmc3BpZmZlOi8vZm9vLmJhci5jb20v + 
Y2xpZW50L3dvcmtsb2FkLzIwHQYDVR0OBBYEFG9GkBgdBg/p0U9/lXv8zIJ+2c2N + MHsGA1UdIwR0MHKhWqRYMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 + YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMM + BnRlc3RjYYIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQELBQADggEB + AJ4Cbxv+02SpUgkEu4hP/1+8DtSBXUxNxI0VG4e3Ap2+Rhjm3YiFeS/UeaZhNrrw + UEjkSTPFODyXR7wI7UO9OO1StyD6CMkp3SEvevU5JsZtGL6mTiTLTi3Qkywa91Bt + GlyZdVMghA1bBJLBMwiD5VT5noqoJBD7hDy6v9yNmt1Sw2iYBJPqI3Gnf5bMjR3s + UICaxmFyqaMCZsPkfJh0DmZpInGJys3m4QqGz6ZE2DWgcSr1r/ML7/5bSPjjr8j4 + WFFSqFR3dMu8CbGnfZTCTXa4GTX/rARXbAO67Z/oJbJBK7VKayskL+PzKuohb9ox + jGL772hQMbwtFCOFXu5VP0s="] + } + ] + } + } +} \ No newline at end of file diff --git a/testing/src/main/resources/certs/spiffebundle1.json b/testing/src/main/resources/certs/spiffebundle1.json new file mode 100644 index 00000000000..f79af09a3e7 --- /dev/null +++ b/testing/src/main/resources/certs/spiffebundle1.json @@ -0,0 +1,59 @@ +{ + "trust_domains": { + "example.com": { + "spiffe_sequence": 12035488, + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIDWjCCAkKgAwIBAgIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTIw + MDMxNzE4NTk1MVoXDTMwMDMxNTE4NTk1MVowVjELMAkGA1UEBhMCQVUxEzARBgNV + BAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0 + ZDEPMA0GA1UEAwwGdGVzdGNhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC + AQEAsGL0oXflF0LzoM+Bh+qUU9yhqzw2w8OOX5mu/iNCyUOBrqaHi7mGHx73GD01 + diNzCzvlcQqdNIH6NQSL7DTpBjca66jYT9u73vZe2MDrr1nVbuLvfu9850cdxiUO + Inv5xf8+sTHG0C+a+VAvMhsLiRjsq+lXKRJyk5zkbbsETybqpxoJ+K7CoSy3yc/k + QIY3TipwEtwkKP4hzyo6KiGd/DPexie4nBUInN3bS1BUeNZ5zeaIC2eg3bkeeW7c + qT55b+Yen6CxY0TEkzBK6AKt/WUialKMgT0wbTxRZO7kUCH3Sq6e/wXeFdJ+HvdV + LPlAg5TnMaNpRdQih/8nRFpsdwIDAQABoyAwHjAMBgNVHRMEBTADAQH/MA4GA1Ud + DwEB/wQEAwICBDANBgkqhkiG9w0BAQsFAAOCAQEAkTrKZjBrJXHps/HrjNCFPb5a + 
THuGPCSsepe1wkKdSp1h4HGRpLoCgcLysCJ5hZhRpHkRihhef+rFHEe60UePQO3S + CVTtdJB4CYWpcNyXOdqefrbJW5QNljxgi6Fhvs7JJkBqdXIkWXtFk2eRgOIP2Eo9 + /OHQHlYnwZFrk6sp4wPyR+A95S0toZBcyDVz7u+hOW0pGK3wviOe9lvRgj/H3Pwt + bewb0l+MhRig0/DVHamyVxrDRbqInU1/GTNCwcZkXKYFWSf92U+kIcTth24Q1gcw + eZiLl5FfrWokUNytFElXob0V0a5/kbhiLc3yWmvWqHTpqCALbVyF+rKJo2f5Kw=="], + "n": "", + "e": "AQAB" + } + ] + }, + "foo.bar.com": { + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIDWjCCAkKgAwIBAgIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTIw + MDMxNzE4NTk1MVoXDTMwMDMxNTE4NTk1MVowVjELMAkGA1UEBhMCQVUxEzARBgNV + BAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0 + ZDEPMA0GA1UEAwwGdGVzdGNhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC + AQEAsGL0oXflF0LzoM+Bh+qUU9yhqzw2w8OOX5mu/iNCyUOBrqaHi7mGHx73GD01 + diNzCzvlcQqdNIH6NQSL7DTpBjca66jYT9u73vZe2MDrr1nVbuLvfu9850cdxiUO + Inv5xf8+sTHG0C+a+VAvMhsLiRjsq+lXKRJyk5zkbbsETybqpxoJ+K7CoSy3yc/k + QIY3TipwEtwkKP4hzyo6KiGd/DPexie4nBUInN3bS1BUeNZ5zeaIC2eg3bkeeW7c + qT55b+Yen6CxY0TEkzBK6AKt/WUialKMgT0wbTxRZO7kUCH3Sq6e/wXeFdJ+HvdV + LPlAg5TnMaNpRdQih/8nRFpsdwIDAQABoyAwHjAMBgNVHRMEBTADAQH/MA4GA1Ud + DwEB/wQEAwICBDANBgkqhkiG9w0BAQsFAAOCAQEAkTrKZjBrJXHps/HrjNCFPb5a + THuGPCSsepe1wkKdSp1h4HGRpLoCgcLysCJ5hZhRpHkRihhef+rFHEe60UePQO3S + CVTtdJB4CYWpcNyXOdqefrbJW5QNljxgi6Fhvs7JJkBqdXIkWXtFk2eRgOIP2Eo9 + /OHQHlYnwZFrk6sp4wPyR+A95S0toZBcyDVz7u+hOW0pGK3wviOe9lvRgj/H3Pwt + bewb0l+MhRig0/DVHamyVxrDRbqInU1/GTNCwcZkXKYFWSf92U+kIcTth24Q1gcw + eZiLl5FfrWokUNytFElXob0V0a5/kbhiLc3yWmvWqHTpqCALbVyF+rKJo2f5Kw=="] + } + ] + } + } +} \ No newline at end of file diff --git a/testing/src/test/java/io/grpc/testing/GrpcCleanupRuleTest.java b/testing/src/test/java/io/grpc/testing/GrpcCleanupRuleTest.java index a5a6783d53f..8eb3edd3825 100644 --- a/testing/src/test/java/io/grpc/testing/GrpcCleanupRuleTest.java +++ 
b/testing/src/test/java/io/grpc/testing/GrpcCleanupRuleTest.java @@ -18,6 +18,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; @@ -35,9 +36,7 @@ import io.grpc.internal.FakeClock; import io.grpc.testing.GrpcCleanupRule.Resource; import java.util.concurrent.TimeUnit; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.junit.runners.model.MultipleFailureException; @@ -51,10 +50,6 @@ public class GrpcCleanupRuleTest { public static final FakeClock fakeClock = new FakeClock(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Test public void registerChannelReturnSameChannel() { ManagedChannel channel = mock(ManagedChannel.class); @@ -72,10 +67,9 @@ public void registerNullChannelThrowsNpe() { ManagedChannel channel = null; GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - thrown.expect(NullPointerException.class); - thrown.expectMessage("channel"); - - grpcCleanup.register(channel); + NullPointerException e = assertThrows(NullPointerException.class, + () -> grpcCleanup.register(channel)); + assertThat(e).hasMessageThat().isEqualTo("channel"); } @Test @@ -83,10 +77,9 @@ public void registerNullServerThrowsNpe() { Server server = null; GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - thrown.expect(NullPointerException.class); - thrown.expectMessage("server"); - - grpcCleanup.register(server); + NullPointerException e = assertThrows(NullPointerException.class, + () -> grpcCleanup.register(server)); + assertThat(e).hasMessageThat().isEqualTo("server"); } @Test diff --git a/util/BUILD.bazel 
b/util/BUILD.bazel index 8fb00e21d56..32d5a367b95 100644 --- a/util/BUILD.bazel +++ b/util/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_library") load("@rules_jvm_external//:defs.bzl", "artifact") java_library( diff --git a/util/build.gradle b/util/build.gradle index 932ca66883e..846b110b106 100644 --- a/util/build.gradle +++ b/util/build.gradle @@ -35,8 +35,16 @@ dependencies { project(':grpc-testing') jmh project(':grpc-testing') - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } animalsniffer { @@ -50,6 +58,7 @@ animalsniffer { tasks.named("javadoc").configure { exclude 'io/grpc/util/MultiChildLoadBalancer.java' exclude 'io/grpc/util/OutlierDetectionLoadBalancer*' + exclude 'io/grpc/util/RandomSubsettingLoadBalancer*' exclude 'io/grpc/util/RoundRobinLoadBalancer*' } diff --git a/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java b/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java index 1f807cd405d..eea664f2ad4 100644 --- a/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java +++ b/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java @@ -32,6 +32,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; import java.util.logging.Logger; import javax.net.ssl.SSLEngine; @@ -40,59 +41,86 @@ /** * AdvancedTlsX509KeyManager is an {@code X509ExtendedKeyManager} that allows users to configure * advanced TLS features, such as private key and certificate chain reloading. + * + *

The alias increments on every credential load (e.g. {@code "key-1"}, {@code "key-2"}, ...), + * so the same alias always maps to the same key material. The previous alias is retained for one + * rotation to allow in-progress handshakes to complete, ensuring alias-to-key-material consistency + * across credential reloads. */ public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager { private static final Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName()); // Minimum allowed period for refreshing files with credential information. - private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1 ; - // The credential information to be sent to peers to prove our identity. - private volatile KeyInfo keyInfo; + private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1; + // Prefix for the key material alias; revision counter appended on each credential load. + static final String ALIAS_PREFIX = "key-"; + + private final AtomicInteger revision = new AtomicInteger(0); + // Snapshot of current and previous KeyInfo; previous is retained for in-progress handshakes + // after one rotation. + private volatile KeyInfoSnapshot snapshot = new KeyInfoSnapshot(null, null); + + public AdvancedTlsX509KeyManager() {} + + private String alias() { + KeyInfo curr = this.snapshot.current; + return curr != null ? 
curr.alias : null; + } @Override public PrivateKey getPrivateKey(String alias) { - if (alias.equals("default")) { - return this.keyInfo.key; + KeyInfoSnapshot snap = this.snapshot; + if (snap.current != null && snap.current.alias.equals(alias)) { + return snap.current.key; + } + if (snap.previous != null && snap.previous.alias.equals(alias)) { + return snap.previous.key; } return null; } @Override public X509Certificate[] getCertificateChain(String alias) { - if (alias.equals("default")) { - return Arrays.copyOf(this.keyInfo.certs, this.keyInfo.certs.length); + KeyInfoSnapshot snap = this.snapshot; + if (snap.current != null && snap.current.alias.equals(alias)) { + return Arrays.copyOf(snap.current.certs, snap.current.certs.length); + } + if (snap.previous != null && snap.previous.alias.equals(alias)) { + return Arrays.copyOf(snap.previous.certs, snap.previous.certs.length); } return null; } @Override public String[] getClientAliases(String keyType, Principal[] issuers) { - return new String[] {"default"}; + String alias = alias(); + return alias != null ? new String[] {alias} : null; } @Override public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { - return "default"; + return alias(); } @Override public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) { - return "default"; + return alias(); } @Override public String[] getServerAliases(String keyType, Principal[] issuers) { - return new String[] {"default"}; + String alias = alias(); + return alias != null ? 
new String[] {alias} : null; } @Override public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { - return "default"; + return alias(); } @Override public String chooseEngineServerAlias(String keyType, Principal[] issuers, SSLEngine engine) { - return "default"; + return alias(); } /** @@ -116,7 +144,9 @@ public void updateIdentityCredentials(PrivateKey key, X509Certificate[] certs) { * @param key the private key that is going to be used */ public void updateIdentityCredentials(X509Certificate[] certs, PrivateKey key) { - this.keyInfo = new KeyInfo(checkNotNull(certs, "certs"), checkNotNull(key, "key")); + KeyInfo newInfo = new KeyInfo(checkNotNull(certs, "certs"), checkNotNull(key, "key"), + ALIAS_PREFIX + revision.incrementAndGet()); + this.snapshot = new KeyInfoSnapshot(newInfo, this.snapshot.current); } /** @@ -218,10 +248,22 @@ private static class KeyInfo { // The private key and the cert chain we will use to send to peers to prove our identity. final X509Certificate[] certs; final PrivateKey key; + final String alias; - public KeyInfo(X509Certificate[] certs, PrivateKey key) { + public KeyInfo(X509Certificate[] certs, PrivateKey key, String alias) { this.certs = certs; this.key = key; + this.alias = alias; + } + } + + private static class KeyInfoSnapshot { + final KeyInfo current; + final KeyInfo previous; + + KeyInfoSnapshot(KeyInfo current, KeyInfo previous) { + this.current = current; + this.previous = previous; } } @@ -309,4 +351,3 @@ public interface Closeable extends java.io.Closeable { void close(); } } - diff --git a/util/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java b/util/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java index 088f4caa000..0739fa3d453 100644 --- a/util/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java +++ b/util/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java @@ -265,7 +265,7 @@ public Closeable updateTrustCredentials(File trustCertFile, long period, TimeUni 
} final ScheduledFuture future = checkNotNull(executor, "executor").scheduleWithFixedDelay( - new LoadFilePathExecution(trustCertFile), period, period, unit); + new LoadFilePathExecution(trustCertFile, updatedTime), period, period, unit); return () -> future.cancel(false); } @@ -312,9 +312,9 @@ private class LoadFilePathExecution implements Runnable { File file; long currentTime; - public LoadFilePathExecution(File file) { + public LoadFilePathExecution(File file, long currentTime) { this.file = file; - this.currentTime = 0; + this.currentTime = currentTime; } @Override @@ -339,6 +339,10 @@ public void run() { private long readAndUpdate(File trustCertFile, long oldTime) throws IOException, GeneralSecurityException { long newTime = checkNotNull(trustCertFile, "trustCertFile").lastModified(); + if (newTime == 0) { + throw new IOException( + "Certificate file not found or not readable: " + trustCertFile.getAbsolutePath()); + } if (newTime == oldTime) { return oldTime; } diff --git a/util/src/main/java/io/grpc/util/ForwardingLoadBalancer.java b/util/src/main/java/io/grpc/util/ForwardingLoadBalancer.java index cefcbf344ea..d52ff42e652 100644 --- a/util/src/main/java/io/grpc/util/ForwardingLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/ForwardingLoadBalancer.java @@ -29,6 +29,7 @@ public abstract class ForwardingLoadBalancer extends LoadBalancer { */ protected abstract LoadBalancer delegate(); + @Deprecated @Override public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { delegate().handleResolvedAddresses(resolvedAddresses); @@ -52,6 +53,8 @@ public void shutdown() { } @Override + @Deprecated + @SuppressWarnings("InlineMeSuggester") public boolean canHandleEmptyAddressListFromNameResolution() { return delegate().canHandleEmptyAddressListFromNameResolution(); } diff --git a/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java b/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java index a63a641b037..27dc080c71b 100644 --- 
a/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java @@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Objects; import io.grpc.ConnectivityState; @@ -38,34 +37,21 @@ /** * A load balancer that gracefully swaps to a new lb policy. If the channel is currently in a state * other than READY, the new policy will be swapped into place immediately. Otherwise, the channel - * will keep using the old policy until the new policy reports READY or the old policy exits READY. + * will keep using the old policy until the new policy leaves CONNECTING or the old policy exits + * READY. * *

The child balancer and configuration is specified using service config. Config objects are * generally created by calling {@link #parseLoadBalancingPolicyConfig(List)} from a * {@link io.grpc.LoadBalancerProvider#parseLoadBalancingPolicyConfig * provider's parseLoadBalancingPolicyConfig()} implementation. - * - *

Alternatively, the balancer may {@link #switchTo(LoadBalancer.Factory) switch to} a policy - * prior to {@link - * LoadBalancer#handleResolvedAddresses(ResolvedAddresses) handling resolved addresses} for the - * first time. This causes graceful switch to ignore the service config and pass through the - * resolved addresses directly to the child policy. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/5999") @NotThreadSafe // Must be accessed in SynchronizationContext public final class GracefulSwitchLoadBalancer extends ForwardingLoadBalancer { private final LoadBalancer defaultBalancer = new LoadBalancer() { @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { - // Most LB policies using this class will receive child policy configuration within the - // service config, so they are naturally calling switchTo() just before - // handleResolvedAddresses(), within their own handleResolvedAddresses(). If switchTo() is - // not called immediately after construction that does open up potential for bugs in the - // parent policies, where they fail to call switchTo(). So we will use the exception to try - // to notice those bugs quickly, as it will fail very loudly. 
- throw new IllegalStateException( - "GracefulSwitchLoadBalancer must switch to a load balancing policy before handling" - + " ResolvedAddresses"); + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + throw new AssertionError("real LB is called instead"); } @Override @@ -79,19 +65,6 @@ public void handleNameResolutionError(final Status error) { public void shutdown() {} }; - @VisibleForTesting - static final SubchannelPicker BUFFER_PICKER = new SubchannelPicker() { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withNoResult(); - } - - @Override - public String toString() { - return "BUFFER_PICKER"; - } - }; - private final Helper helper; // While the new policy is not fully switched on, the pendingLb is handling new updates from name @@ -104,7 +77,6 @@ public String toString() { private LoadBalancer pendingLb = defaultBalancer; private ConnectivityState pendingState; private SubchannelPicker pendingPicker; - private boolean switchToCalled; private boolean currentLbIsReady; @@ -112,12 +84,9 @@ public GracefulSwitchLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); } + @Deprecated @Override public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { - if (switchToCalled) { - delegate().handleResolvedAddresses(resolvedAddresses); - return; - } Config config = (Config) resolvedAddresses.getLoadBalancingPolicyConfig(); switchToInternal(config.childFactory); delegate().handleResolvedAddresses( @@ -128,9 +97,6 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { @Override public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - if (switchToCalled) { - return delegate().acceptResolvedAddresses(resolvedAddresses); - } Config config = (Config) resolvedAddresses.getLoadBalancingPolicyConfig(); switchToInternal(config.childFactory); return delegate().acceptResolvedAddresses( @@ -139,19 +105,6 @@ public Status 
acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { .build()); } - /** - * Gracefully switch to a new policy defined by the given factory, if the given factory isn't - * equal to the current one. - * - * @deprecated Use {@code parseLoadBalancingPolicyConfig()} and pass the configuration to - * {@link io.grpc.LoadBalancer.ResolvedAddresses.Builder#setLoadBalancingPolicyConfig} - */ - @Deprecated - public void switchTo(LoadBalancer.Factory newBalancerFactory) { - switchToCalled = true; - switchToInternal(newBalancerFactory); - } - private void switchToInternal(LoadBalancer.Factory newBalancerFactory) { checkNotNull(newBalancerFactory, "newBalancerFactory"); @@ -162,7 +115,7 @@ private void switchToInternal(LoadBalancer.Factory newBalancerFactory) { pendingLb = defaultBalancer; pendingBalancerFactory = null; pendingState = ConnectivityState.CONNECTING; - pendingPicker = BUFFER_PICKER; + pendingPicker = new FixedResultPicker(PickResult.withNoResult()); if (newBalancerFactory.equals(currentBalancerFactory)) { return; @@ -182,7 +135,7 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne checkState(currentLbIsReady, "there's pending lb while current lb has been out of READY"); pendingState = newState; pendingPicker = newPicker; - if (newState == ConnectivityState.READY) { + if (newState != ConnectivityState.CONNECTING) { swap(); } } else if (lb == currentLb) { @@ -255,14 +208,14 @@ public static ConfigOrError parseLoadBalancingPolicyConfig( ServiceConfigUtil.unwrapLoadBalancingConfigList(loadBalancingConfigs); if (childConfigCandidates == null || childConfigCandidates.isEmpty()) { return ConfigOrError.fromError( - Status.INTERNAL.withDescription("No child LB config specified")); + Status.UNAVAILABLE.withDescription("No child LB config specified")); } ConfigOrError selectedConfig = ServiceConfigUtil.selectLbPolicyFromList(childConfigCandidates, lbRegistry); if (selectedConfig.getError() != null) { Status error = 
selectedConfig.getError(); return ConfigOrError.fromError( - Status.INTERNAL + Status.UNAVAILABLE .withCause(error.getCause()) .withDescription(error.getDescription()) .augmentDescription("Failed to select child config")); diff --git a/util/src/main/java/io/grpc/util/HealthProducerHelper.java b/util/src/main/java/io/grpc/util/HealthProducerHelper.java index b11864765ea..d871911d203 100644 --- a/util/src/main/java/io/grpc/util/HealthProducerHelper.java +++ b/util/src/main/java/io/grpc/util/HealthProducerHelper.java @@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Internal; import io.grpc.LoadBalancer; @@ -84,6 +85,31 @@ protected LoadBalancer.Helper delegate() { return delegate; } + @Override + public void updateBalancingState( + ConnectivityState newState, LoadBalancer.SubchannelPicker newPicker) { + delegate.updateBalancingState(newState, new HealthProducerPicker(newPicker)); + } + + private static final class HealthProducerPicker extends LoadBalancer.SubchannelPicker { + private final LoadBalancer.SubchannelPicker delegate; + + HealthProducerPicker(LoadBalancer.SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) { + LoadBalancer.PickResult result = delegate.pickSubchannel(args); + LoadBalancer.Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof HealthProducerSubchannel) { + return result.copyWithSubchannel( + ((HealthProducerSubchannel) subchannel).delegate()); + } + return result; + } + } + // The parent subchannel in the health check producer LB chain. It duplicates subchannel state to // both the state listener and health listener. 
@VisibleForTesting diff --git a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java index 626c2e1104e..acc186e3be6 100644 --- a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java @@ -24,7 +24,9 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.collect.Maps; +import com.google.common.primitives.UnsignedInts; import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; @@ -37,12 +39,10 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Set; +import java.util.Random; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -55,7 +55,9 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { private static final Logger logger = Logger.getLogger(MultiChildLoadBalancer.class.getName()); - private final Map childLbStates = new LinkedHashMap<>(); + private static final int OFFSET_SEED = new Random().nextInt(); + // Modify by replacing the list to release memory when no longer used. + private List childLbStates = new ArrayList<>(0); private final Helper helper; // Set to true if currently in the process of handling resolved addresses. protected boolean resolvingAddresses; @@ -79,11 +81,13 @@ protected MultiChildLoadBalancer(Helper helper) { /** * Override to utilize parsing of the policy configuration or alternative helper/lb generation. - * Override this if keys are not Endpoints or if child policies have configuration. 
+ * Override this if keys are not Endpoints or if child policies have configuration. Null map + * values preserve the child without delivering the child an update. */ protected Map createChildAddressesMap( ResolvedAddresses resolvedAddresses) { - Map childAddresses = new HashMap<>(); + Map childAddresses = + Maps.newLinkedHashMapWithExpectedSize(resolvedAddresses.getAddresses().size()); for (EquivalentAddressGroup eag : resolvedAddresses.getAddresses()) { ResolvedAddresses addresses = resolvedAddresses.toBuilder() .setAddresses(Collections.singletonList(eag)) @@ -107,21 +111,22 @@ protected ChildLbState createChildLbState(Object key) { */ @Override public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses); try { resolvingAddresses = true; // process resolvedAddresses to update children - AcceptResolvedAddrRetVal acceptRetVal = acceptResolvedAddressesInternal(resolvedAddresses); - if (!acceptRetVal.status.isOk()) { - return acceptRetVal.status; + Map newChildAddresses = createChildAddressesMap(resolvedAddresses); + + // Handle error case + if (newChildAddresses.isEmpty()) { + Status unavailableStatus = Status.UNAVAILABLE.withDescription( + "NameResolver returned no usable address. 
" + resolvedAddresses); + handleNameResolutionError(unavailableStatus); + return unavailableStatus; } - // Update the picker and our connectivity state - updateOverallBalancingState(); - - // shutdown removed children - shutdownRemoved(acceptRetVal.removedChildren); - return acceptRetVal.status; + return updateChildrenWithResolvedAddresses(newChildAddresses); } finally { resolvingAddresses = false; } @@ -143,70 +148,67 @@ public void handleNameResolutionError(Status error) { @Override public void shutdown() { logger.log(Level.FINE, "Shutdown"); - for (ChildLbState state : childLbStates.values()) { + for (ChildLbState state : childLbStates) { state.shutdown(); } childLbStates.clear(); } - /** - * This does the work to update the child map and calculate which children have been removed. - * You must call {@link #updateOverallBalancingState} to update the picker - * and call {@link #shutdownRemoved(List)} to shutdown the endpoints that have been removed. - */ - protected final AcceptResolvedAddrRetVal acceptResolvedAddressesInternal( - ResolvedAddresses resolvedAddresses) { - logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses); - - Map newChildAddresses = createChildAddressesMap(resolvedAddresses); - - // Handle error case - if (newChildAddresses.isEmpty()) { - Status unavailableStatus = Status.UNAVAILABLE.withDescription( - "NameResolver returned no usable address. 
" + resolvedAddresses); - handleNameResolutionError(unavailableStatus); - return new AcceptResolvedAddrRetVal(unavailableStatus, null); + private Status updateChildrenWithResolvedAddresses( + Map newChildAddresses) { + // Create a map with the old values + Map oldStatesMap = + Maps.newLinkedHashMapWithExpectedSize(childLbStates.size()); + for (ChildLbState state : childLbStates) { + oldStatesMap.put(state.getKey(), state); } - updateChildrenWithResolvedAddresses(newChildAddresses); - - return new AcceptResolvedAddrRetVal(Status.OK, getRemovedChildren(newChildAddresses.keySet())); - } - - private void updateChildrenWithResolvedAddresses( - Map newChildAddresses) { + // Move ChildLbStates from the map to a new list (preserving the new map's order) + Status status = Status.OK; + List newChildLbStates = new ArrayList<>(newChildAddresses.size()); for (Map.Entry entry : newChildAddresses.entrySet()) { - ChildLbState childLbState = childLbStates.get(entry.getKey()); + ChildLbState childLbState = oldStatesMap.remove(entry.getKey()); if (childLbState == null) { childLbState = createChildLbState(entry.getKey()); - childLbStates.put(entry.getKey(), childLbState); } - childLbState.setResolvedAddresses(entry.getValue()); // update child - childLbState.lb.handleResolvedAddresses(entry.getValue()); // update child LB - } - } - - /** - * Identifies which children have been removed (are not part of the newChildKeys). - */ - private List getRemovedChildren(Set newChildKeys) { - List removedChildren = new ArrayList<>(); - // Do removals - for (Object key : ImmutableList.copyOf(childLbStates.keySet())) { - if (!newChildKeys.contains(key)) { - ChildLbState childLbState = childLbStates.remove(key); - removedChildren.add(childLbState); + newChildLbStates.add(childLbState); + } + // Use a random start position for child updates to weakly "shuffle" connection creation order. 
+ // The network will often add noise to the creation order, but this avoids giving earlier + // children a consistent head start. + for (ChildLbState childLbState : offsetIterable(newChildLbStates, OFFSET_SEED)) { + ResolvedAddresses addresses = newChildAddresses.get(childLbState.getKey()); + if (addresses != null) { + // update child LB + Status newStatus = childLbState.lb.acceptResolvedAddresses(addresses); + if (!newStatus.isOk()) { + status = newStatus; + } } } - return removedChildren; - } - protected final void shutdownRemoved(List removedChildren) { - // Do shutdowns after updating picker to reduce the chance of failing an RPC by picking a - // subchannel that has been shutdown. - for (ChildLbState childLbState : removedChildren) { + childLbStates = newChildLbStates; + // Update the picker and our connectivity state + updateOverallBalancingState(); + + // Remaining entries in map are orphaned + for (ChildLbState childLbState : oldStatesMap.values()) { childLbState.shutdown(); } + return status; + } + + @VisibleForTesting + static Iterable offsetIterable(Collection c, int seed) { + int pos; + if (c.isEmpty()) { + pos = 0; + } else { + pos = UnsignedInts.remainder(seed, c.size()); + } + return Iterables.concat( + Iterables.skip(c, pos), + Iterables.limit(c, pos)); } @Nullable @@ -233,23 +235,7 @@ protected final Helper getHelper() { @VisibleForTesting public final Collection getChildLbStates() { - return childLbStates.values(); - } - - @VisibleForTesting - public final ChildLbState getChildLbState(Object key) { - if (key == null) { - return null; - } - if (key instanceof EquivalentAddressGroup) { - key = new Endpoint((EquivalentAddressGroup) key); - } - return childLbStates.get(key); - } - - @VisibleForTesting - public final ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) { - return getChildLbState(new Endpoint(eag)); + return childLbStates; } /** @@ -278,12 +264,12 @@ protected final List getReadyChildren() { */ public class ChildLbState { 
private final Object key; - private ResolvedAddresses resolvedAddresses; - private final LoadBalancer lb; private ConnectivityState currentState; private SubchannelPicker currentPicker = new FixedResultPicker(PickResult.withNoResult()); + @SuppressWarnings("this-escape") + // TODO(okshiva): Fix 'this-escape' from the constructor before making the API public. public ChildLbState(Object key, LoadBalancer.Factory policyFactory) { this.key = key; this.lb = policyFactory.newLoadBalancer(createChildHelper()); @@ -337,23 +323,6 @@ protected final void setCurrentPicker(SubchannelPicker newPicker) { currentPicker = newPicker; } - public final EquivalentAddressGroup getEag() { - if (resolvedAddresses == null || resolvedAddresses.getAddresses().isEmpty()) { - return null; - } - return resolvedAddresses.getAddresses().get(0); - } - - protected final void setResolvedAddresses(ResolvedAddresses newAddresses) { - checkNotNull(newAddresses, "Missing address list for child"); - resolvedAddresses = newAddresses; - } - - @VisibleForTesting - public final ResolvedAddresses getResolvedAddresses() { - return resolvedAddresses; - } - /** * ChildLbStateHelper is the glue between ChildLbState and the helpers associated with the * petiole policy above and the PickFirstLoadBalancer's helper below. 
@@ -442,14 +411,4 @@ public String toString() { return addrs.toString(); } } - - protected static class AcceptResolvedAddrRetVal { - public final Status status; - public final List removedChildren; - - public AcceptResolvedAddrRetVal(Status status, List removedChildren) { - this.status = status; - this.removedChildren = removedChildren; - } - } } diff --git a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java index 1f0290e76d7..dc61441bccd 100644 --- a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java @@ -22,6 +22,7 @@ import static java.util.concurrent.TimeUnit.NANOSECONDS; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Ticker; import com.google.common.collect.ForwardingMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -39,7 +40,6 @@ import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; -import io.grpc.internal.TimeProvider; import java.net.SocketAddress; import java.util.ArrayList; import java.util.Collection; @@ -82,7 +82,7 @@ public final class OutlierDetectionLoadBalancer extends LoadBalancer { private final SynchronizationContext syncContext; private final Helper childHelper; private final GracefulSwitchLoadBalancer switchLb; - private TimeProvider timeProvider; + private Ticker ticker; private final ScheduledExecutorService timeService; private ScheduledHandle detectionTimerHandle; private Long detectionTimerStartNanos; @@ -95,14 +95,14 @@ public final class OutlierDetectionLoadBalancer extends LoadBalancer { /** * Creates a new instance of {@link OutlierDetectionLoadBalancer}. 
*/ - public OutlierDetectionLoadBalancer(Helper helper, TimeProvider timeProvider) { + public OutlierDetectionLoadBalancer(Helper helper, Ticker ticker) { logger = helper.getChannelLogger(); childHelper = new ChildHelper(checkNotNull(helper, "helper")); switchLb = new GracefulSwitchLoadBalancer(childHelper); endpointTrackerMap = new EndpointTrackerMap(); this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); - this.timeProvider = timeProvider; + this.ticker = ticker; logger.log(ChannelLogLevel.DEBUG, "OutlierDetection lb created."); } @@ -142,7 +142,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { // If outlier detection is actually configured, start a timer that will periodically try to // detect outliers. if (config.outlierDetectionEnabled()) { - Long initialDelayNanos; + long initialDelayNanos; if (detectionTimerStartNanos == null) { // On the first go we use the configured interval. @@ -151,7 +151,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { // If a timer has started earlier we cancel it and use the difference between the start // time and now as the interval. 
initialDelayNanos = Math.max(0L, - config.intervalNanos - (timeProvider.currentTimeNanos() - detectionTimerStartNanos)); + config.intervalNanos - (ticker.read() - detectionTimerStartNanos)); } // If a timer has been previously created we need to cancel it and reset all the call counters @@ -171,9 +171,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { endpointTrackerMap.cancelTracking(); } - switchLb.handleResolvedAddresses( + return switchLb.acceptResolvedAddresses( resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(config.childConfig).build()); - return Status.OK; } @Override @@ -190,7 +189,7 @@ public void shutdown() { * This timer will be invoked periodically, according to configuration, and it will look for any * outlier subchannels. */ - class DetectionTimer implements Runnable { + final class DetectionTimer implements Runnable { OutlierDetectionLoadBalancerConfig config; ChannelLogger logger; @@ -202,7 +201,7 @@ class DetectionTimer implements Runnable { @Override public void run() { - detectionTimerStartNanos = timeProvider.currentTimeNanos(); + detectionTimerStartNanos = ticker.read(); endpointTrackerMap.swapCounters(); @@ -218,7 +217,7 @@ public void run() { * This child helper wraps the provided helper so that it can hand out wrapped {@link * OutlierDetectionSubchannel}s and manage the address info map. */ - class ChildHelper extends ForwardingLoadBalancerHelper { + final class ChildHelper extends ForwardingLoadBalancerHelper { private Helper delegate; @@ -260,7 +259,7 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne } } - class OutlierDetectionSubchannel extends ForwardingSubchannel { + final class OutlierDetectionSubchannel extends ForwardingSubchannel { private final Subchannel delegate; private EndpointTracker endpointTracker; @@ -399,7 +398,7 @@ protected Subchannel delegate() { /** * Wraps the actual listener so that state changes from the actual one can be intercepted. 
*/ - class OutlierDetectionSubchannelStateListener implements SubchannelStateListener { + final class OutlierDetectionSubchannelStateListener implements SubchannelStateListener { private final SubchannelStateListener delegate; @@ -429,7 +428,7 @@ public String toString() { * This picker delegates the actual picking logic to a wrapped delegate, but associates a {@link * ClientStreamTracer} with each pick to track the results of each subchannel stream. */ - class OutlierDetectionPicker extends SubchannelPicker { + final class OutlierDetectionPicker extends SubchannelPicker { private final SubchannelPicker delegate; @@ -443,9 +442,14 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { Subchannel subchannel = pickResult.getSubchannel(); if (subchannel != null) { - return PickResult.withSubchannel(subchannel, new ResultCountingClientStreamTracerFactory( - subchannel.getAttributes().get(ENDPOINT_TRACKER_KEY), - pickResult.getStreamTracerFactory())); + EndpointTracker tracker = subchannel.getAttributes().get(ENDPOINT_TRACKER_KEY); + if (subchannel instanceof OutlierDetectionSubchannel) { + subchannel = ((OutlierDetectionSubchannel) subchannel).delegate(); + } + return pickResult.copyWithSubchannel(subchannel) + .copyWithStreamTracerFactory(new ResultCountingClientStreamTracerFactory( + tracker, + pickResult.getStreamTracerFactory())); } return pickResult; @@ -455,7 +459,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { * Builds instances of a {@link ClientStreamTracer} that increments the call count in the * tracker for each closed stream. */ - class ResultCountingClientStreamTracerFactory extends ClientStreamTracer.Factory { + final class ResultCountingClientStreamTracerFactory extends ClientStreamTracer.Factory { private final EndpointTracker tracker; @@ -499,7 +503,7 @@ public void streamClosed(Status status) { /** * Tracks additional information about the endpoint needed for outlier detection. 
*/ - static class EndpointTracker { + static final class EndpointTracker { private OutlierDetectionLoadBalancerConfig config; // Marked as volatile to assure that when the inactive counter is swapped in as the new active @@ -639,11 +643,11 @@ public boolean maxEjectionTimeElapsed(long currentTimeNanos) { config.baseEjectionTimeNanos * ejectionTimeMultiplier, maxEjectionDurationSecs); - return currentTimeNanos > maxEjectionTimeNanos; + return currentTimeNanos - maxEjectionTimeNanos > 0; } /** Tracks both successful and failed call counts. */ - private static class CallCounter { + private static final class CallCounter { AtomicLong successCount = new AtomicLong(); AtomicLong failureCount = new AtomicLong(); @@ -664,7 +668,7 @@ public String toString() { /** * Maintains a mapping from endpoint (a set of addresses) to their trackers. */ - static class EndpointTrackerMap extends ForwardingMap, EndpointTracker> { + static final class EndpointTrackerMap extends ForwardingMap, EndpointTracker> { private final Map, EndpointTracker> trackerMap; EndpointTrackerMap() { @@ -685,7 +689,11 @@ void updateTrackerConfigs(OutlierDetectionLoadBalancerConfig config) { /** Adds a new tracker for every given address. */ void putNewTrackers(OutlierDetectionLoadBalancerConfig config, Set> endpoints) { - endpoints.forEach(e -> trackerMap.putIfAbsent(e, new EndpointTracker(config))); + for (Set endpoint : endpoints) { + if (!trackerMap.containsKey(endpoint)) { + trackerMap.put(endpoint, new EndpointTracker(config)); + } + } } /** Resets the call counters for all the trackers in the map. */ @@ -720,7 +728,7 @@ void swapCounters() { * that don't have ejected subchannels and uneject ones that have spent the maximum ejection * time allowed. 
*/ - void maybeUnejectOutliers(Long detectionTimerStartNanos) { + void maybeUnejectOutliers(long detectionTimerStartNanos) { for (EndpointTracker tracker : trackerMap.values()) { if (!tracker.subchannelsEjected()) { tracker.decrementEjectionTimeMultiplier(); @@ -781,7 +789,7 @@ static List forConfig(OutlierDetectionLoadBalancerConf * required rate is not fixed, but is based on the mean and standard deviation of the success * rates of all of the addresses. */ - static class SuccessRateOutlierEjectionAlgorithm implements OutlierEjectionAlgorithm { + static final class SuccessRateOutlierEjectionAlgorithm implements OutlierEjectionAlgorithm { private final OutlierDetectionLoadBalancerConfig config; @@ -866,7 +874,7 @@ static double standardDeviation(Collection values, double mean) { } } - static class FailurePercentageOutlierEjectionAlgorithm implements OutlierEjectionAlgorithm { + static final class FailurePercentageOutlierEjectionAlgorithm implements OutlierEjectionAlgorithm { private final OutlierDetectionLoadBalancerConfig config; @@ -948,64 +956,54 @@ private static boolean hasSingleAddress(List addressGrou */ public static final class OutlierDetectionLoadBalancerConfig { - public final Long intervalNanos; - public final Long baseEjectionTimeNanos; - public final Long maxEjectionTimeNanos; - public final Integer maxEjectionPercent; + public final long intervalNanos; + public final long baseEjectionTimeNanos; + public final long maxEjectionTimeNanos; + public final int maxEjectionPercent; public final SuccessRateEjection successRateEjection; public final FailurePercentageEjection failurePercentageEjection; public final Object childConfig; - private OutlierDetectionLoadBalancerConfig(Long intervalNanos, - Long baseEjectionTimeNanos, - Long maxEjectionTimeNanos, - Integer maxEjectionPercent, - SuccessRateEjection successRateEjection, - FailurePercentageEjection failurePercentageEjection, - Object childConfig) { - this.intervalNanos = intervalNanos; - 
this.baseEjectionTimeNanos = baseEjectionTimeNanos; - this.maxEjectionTimeNanos = maxEjectionTimeNanos; - this.maxEjectionPercent = maxEjectionPercent; - this.successRateEjection = successRateEjection; - this.failurePercentageEjection = failurePercentageEjection; - this.childConfig = childConfig; + private OutlierDetectionLoadBalancerConfig(Builder builder) { + this.intervalNanos = builder.intervalNanos; + this.baseEjectionTimeNanos = builder.baseEjectionTimeNanos; + this.maxEjectionTimeNanos = builder.maxEjectionTimeNanos; + this.maxEjectionPercent = builder.maxEjectionPercent; + this.successRateEjection = builder.successRateEjection; + this.failurePercentageEjection = builder.failurePercentageEjection; + this.childConfig = builder.childConfig; } /** Builds a new {@link OutlierDetectionLoadBalancerConfig}. */ - public static class Builder { - Long intervalNanos = 10_000_000_000L; // 10s - Long baseEjectionTimeNanos = 30_000_000_000L; // 30s - Long maxEjectionTimeNanos = 300_000_000_000L; // 300s - Integer maxEjectionPercent = 10; + public static final class Builder { + long intervalNanos = 10_000_000_000L; // 10s + long baseEjectionTimeNanos = 30_000_000_000L; // 30s + long maxEjectionTimeNanos = 300_000_000_000L; // 300s + int maxEjectionPercent = 10; SuccessRateEjection successRateEjection; FailurePercentageEjection failurePercentageEjection; Object childConfig; /** The interval between outlier detection sweeps. */ - public Builder setIntervalNanos(Long intervalNanos) { - checkArgument(intervalNanos != null); + public Builder setIntervalNanos(long intervalNanos) { this.intervalNanos = intervalNanos; return this; } /** The base time an address is ejected for. */ - public Builder setBaseEjectionTimeNanos(Long baseEjectionTimeNanos) { - checkArgument(baseEjectionTimeNanos != null); + public Builder setBaseEjectionTimeNanos(long baseEjectionTimeNanos) { this.baseEjectionTimeNanos = baseEjectionTimeNanos; return this; } /** The longest time an address can be ejected. 
*/ - public Builder setMaxEjectionTimeNanos(Long maxEjectionTimeNanos) { - checkArgument(maxEjectionTimeNanos != null); + public Builder setMaxEjectionTimeNanos(long maxEjectionTimeNanos) { this.maxEjectionTimeNanos = maxEjectionTimeNanos; return this; } /** The algorithm agnostic maximum percentage of addresses that can be ejected. */ - public Builder setMaxEjectionPercent(Integer maxEjectionPercent) { - checkArgument(maxEjectionPercent != null); + public Builder setMaxEjectionPercent(int maxEjectionPercent) { this.maxEjectionPercent = maxEjectionPercent; return this; } @@ -1037,64 +1035,57 @@ public Builder setChildConfig(Object childConfig) { /** Builds a new instance of {@link OutlierDetectionLoadBalancerConfig}. */ public OutlierDetectionLoadBalancerConfig build() { checkState(childConfig != null); - return new OutlierDetectionLoadBalancerConfig(intervalNanos, baseEjectionTimeNanos, - maxEjectionTimeNanos, maxEjectionPercent, successRateEjection, - failurePercentageEjection, childConfig); + return new OutlierDetectionLoadBalancerConfig(this); } } /** The configuration for success rate ejection. 
*/ - public static class SuccessRateEjection { - - public final Integer stdevFactor; - public final Integer enforcementPercentage; - public final Integer minimumHosts; - public final Integer requestVolume; - - SuccessRateEjection(Integer stdevFactor, Integer enforcementPercentage, Integer minimumHosts, - Integer requestVolume) { - this.stdevFactor = stdevFactor; - this.enforcementPercentage = enforcementPercentage; - this.minimumHosts = minimumHosts; - this.requestVolume = requestVolume; + public static final class SuccessRateEjection { + + public final int stdevFactor; + public final int enforcementPercentage; + public final int minimumHosts; + public final int requestVolume; + + SuccessRateEjection(Builder builder) { + this.stdevFactor = builder.stdevFactor; + this.enforcementPercentage = builder.enforcementPercentage; + this.minimumHosts = builder.minimumHosts; + this.requestVolume = builder.requestVolume; } /** Builds new instances of {@link SuccessRateEjection}. */ public static final class Builder { - Integer stdevFactor = 1900; - Integer enforcementPercentage = 100; - Integer minimumHosts = 5; - Integer requestVolume = 100; + int stdevFactor = 1900; + int enforcementPercentage = 100; + int minimumHosts = 5; + int requestVolume = 100; /** The product of this and the standard deviation of success rates determine the ejection * threshold. */ - public Builder setStdevFactor(Integer stdevFactor) { - checkArgument(stdevFactor != null); + public Builder setStdevFactor(int stdevFactor) { this.stdevFactor = stdevFactor; return this; } /** Only eject this percentage of outliers. 
*/ - public Builder setEnforcementPercentage(Integer enforcementPercentage) { - checkArgument(enforcementPercentage != null); + public Builder setEnforcementPercentage(int enforcementPercentage) { checkArgument(enforcementPercentage >= 0 && enforcementPercentage <= 100); this.enforcementPercentage = enforcementPercentage; return this; } /** The minimum amount of hosts needed for success rate ejection. */ - public Builder setMinimumHosts(Integer minimumHosts) { - checkArgument(minimumHosts != null); + public Builder setMinimumHosts(int minimumHosts) { checkArgument(minimumHosts >= 0); this.minimumHosts = minimumHosts; return this; } /** The minimum address request volume to be considered for success rate ejection. */ - public Builder setRequestVolume(Integer requestVolume) { - checkArgument(requestVolume != null); + public Builder setRequestVolume(int requestVolume) { checkArgument(requestVolume >= 0); this.requestVolume = requestVolume; return this; @@ -1102,53 +1093,48 @@ public Builder setRequestVolume(Integer requestVolume) { /** Builds a new instance of {@link SuccessRateEjection}. */ public SuccessRateEjection build() { - return new SuccessRateEjection(stdevFactor, enforcementPercentage, minimumHosts, - requestVolume); + return new SuccessRateEjection(this); } } } /** The configuration for failure percentage ejection. 
*/ - public static class FailurePercentageEjection { - public final Integer threshold; - public final Integer enforcementPercentage; - public final Integer minimumHosts; - public final Integer requestVolume; - - FailurePercentageEjection(Integer threshold, Integer enforcementPercentage, - Integer minimumHosts, Integer requestVolume) { - this.threshold = threshold; - this.enforcementPercentage = enforcementPercentage; - this.minimumHosts = minimumHosts; - this.requestVolume = requestVolume; + public static final class FailurePercentageEjection { + public final int threshold; + public final int enforcementPercentage; + public final int minimumHosts; + public final int requestVolume; + + FailurePercentageEjection(Builder builder) { + this.threshold = builder.threshold; + this.enforcementPercentage = builder.enforcementPercentage; + this.minimumHosts = builder.minimumHosts; + this.requestVolume = builder.requestVolume; } /** For building new {@link FailurePercentageEjection} instances. */ public static class Builder { - Integer threshold = 85; - Integer enforcementPercentage = 100; - Integer minimumHosts = 5; - Integer requestVolume = 50; + int threshold = 85; + int enforcementPercentage = 100; + int minimumHosts = 5; + int requestVolume = 50; /** The failure percentage that will result in an address being considered an outlier. */ - public Builder setThreshold(Integer threshold) { - checkArgument(threshold != null); + public Builder setThreshold(int threshold) { checkArgument(threshold >= 0 && threshold <= 100); this.threshold = threshold; return this; } /** Only eject this percentage of outliers. 
*/ - public Builder setEnforcementPercentage(Integer enforcementPercentage) { - checkArgument(enforcementPercentage != null); + public Builder setEnforcementPercentage(int enforcementPercentage) { checkArgument(enforcementPercentage >= 0 && enforcementPercentage <= 100); this.enforcementPercentage = enforcementPercentage; return this; } /** The minimum amount of host for failure percentage ejection to be enabled. */ - public Builder setMinimumHosts(Integer minimumHosts) { - checkArgument(minimumHosts != null); + public Builder setMinimumHosts(int minimumHosts) { checkArgument(minimumHosts >= 0); this.minimumHosts = minimumHosts; return this; @@ -1158,8 +1144,7 @@ public Builder setMinimumHosts(Integer minimumHosts) { * The request volume required for an address to be considered for failure percentage * ejection. */ - public Builder setRequestVolume(Integer requestVolume) { - checkArgument(requestVolume != null); + public Builder setRequestVolume(int requestVolume) { checkArgument(requestVolume >= 0); this.requestVolume = requestVolume; return this; @@ -1167,8 +1152,7 @@ public Builder setRequestVolume(Integer requestVolume) { /** Builds a new instance of {@link FailurePercentageEjection}. 
*/ public FailurePercentageEjection build() { - return new FailurePercentageEjection(threshold, enforcementPercentage, minimumHosts, - requestVolume); + return new FailurePercentageEjection(this); } } } diff --git a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancerProvider.java b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancerProvider.java index b35e1144581..084898bc38f 100644 --- a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancerProvider.java +++ b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancerProvider.java @@ -16,14 +16,15 @@ package io.grpc.util; +import com.google.common.base.Ticker; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; -import io.grpc.internal.TimeProvider; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig.FailurePercentageEjection; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig.SuccessRateEjection; @@ -34,7 +35,7 @@ public final class OutlierDetectionLoadBalancerProvider extends LoadBalancerProv @Override public LoadBalancer newLoadBalancer(Helper helper) { - return new OutlierDetectionLoadBalancer(helper, TimeProvider.SYSTEM_TIME_PROVIDER); + return new OutlierDetectionLoadBalancer(helper, Ticker.systemTicker()); } @Override @@ -148,9 +149,10 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map rawC ConfigOrError childConfig = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( JsonUtil.getListOfObjects(rawConfig, "childPolicy")); if (childConfig.getError() != null) { - return ConfigOrError.fromError(Status.INTERNAL - .withDescription("Failed to parse child in outlier_detection_experimental: " + rawConfig) - 
.withCause(childConfig.getError().asRuntimeException())); + return ConfigOrError.fromError(GrpcUtil.statusWithDetails( + Status.Code.UNAVAILABLE, + "Failed to parse child in outlier_detection_experimental", + childConfig.getError())); } configBuilder.setChildConfig(childConfig.getConfig()); diff --git a/util/src/main/java/io/grpc/util/RandomSubsettingLoadBalancer.java b/util/src/main/java/io/grpc/util/RandomSubsettingLoadBalancer.java new file mode 100644 index 00000000000..ad4de9e8921 --- /dev/null +++ b/util/src/main/java/io/grpc/util/RandomSubsettingLoadBalancer.java @@ -0,0 +1,161 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.util; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.hash.HashCode; +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; +import com.google.common.primitives.Ints; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.Status; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.Random; + + +/** + * Wraps a child {@code LoadBalancer}, separating the total set of backends into smaller subsets for + * the child balancer to balance across. + * + *

This implements random subsetting gRFC: + * https://github.com/grpc/proposal/blob/master/A68-random-subsetting.md + */ +final class RandomSubsettingLoadBalancer extends LoadBalancer { + private final GracefulSwitchLoadBalancer switchLb; + private final HashFunction hashFunc; + + public RandomSubsettingLoadBalancer(Helper helper) { + this(helper, new Random().nextInt()); + } + + @VisibleForTesting + RandomSubsettingLoadBalancer(Helper helper, int seed) { + switchLb = new GracefulSwitchLoadBalancer(checkNotNull(helper, "helper")); + hashFunc = Hashing.murmur3_128(seed); + } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + RandomSubsettingLoadBalancerConfig config = + (RandomSubsettingLoadBalancerConfig) + resolvedAddresses.getLoadBalancingPolicyConfig(); + + ResolvedAddresses subsetAddresses = filterEndpoints(resolvedAddresses, config.subsetSize); + + return switchLb.acceptResolvedAddresses( + subsetAddresses.toBuilder() + .setLoadBalancingPolicyConfig(config.childConfig) + .build()); + } + + // implements the subsetting algorithm, as described in A68: + // https://github.com/grpc/proposal/pull/423 + private ResolvedAddresses filterEndpoints(ResolvedAddresses resolvedAddresses, int subsetSize) { + if (subsetSize >= resolvedAddresses.getAddresses().size()) { + return resolvedAddresses; + } + + ArrayList endpointWithHashList = + new ArrayList<>(resolvedAddresses.getAddresses().size()); + + for (EquivalentAddressGroup addressGroup : resolvedAddresses.getAddresses()) { + HashCode hashCode = hashFunc.hashString( + addressGroup.getAddresses().get(0).toString(), + StandardCharsets.UTF_8); + endpointWithHashList.add(new EndpointWithHash(addressGroup, hashCode.asLong())); + } + + Collections.sort(endpointWithHashList, new HashAddressComparator()); + + ArrayList addressGroups = new ArrayList<>(subsetSize); + + for (int idx = 0; idx < subsetSize; ++idx) { + addressGroups.add(endpointWithHashList.get(idx).addressGroup); + } + 
+ return resolvedAddresses.toBuilder().setAddresses(addressGroups).build(); + } + + @Override + public void handleNameResolutionError(Status error) { + switchLb.handleNameResolutionError(error); + } + + @Override + public void shutdown() { + switchLb.shutdown(); + } + + private static final class EndpointWithHash { + public final EquivalentAddressGroup addressGroup; + public final long hashCode; + + public EndpointWithHash(EquivalentAddressGroup addressGroup, long hashCode) { + this.addressGroup = addressGroup; + this.hashCode = hashCode; + } + } + + private static final class HashAddressComparator implements Comparator { + @Override + public int compare(EndpointWithHash lhs, EndpointWithHash rhs) { + return Long.compare(lhs.hashCode, rhs.hashCode); + } + } + + public static final class RandomSubsettingLoadBalancerConfig { + public final int subsetSize; + public final Object childConfig; + + private RandomSubsettingLoadBalancerConfig(int subsetSize, Object childConfig) { + this.subsetSize = subsetSize; + this.childConfig = childConfig; + } + + public static class Builder { + int subsetSize; + Object childConfig; + + public Builder setSubsetSize(long subsetSize) { + checkArgument(subsetSize > 0L, "Subset size must be greater than 0"); + // clamping subset size to Integer.MAX_VALUE due to collection indexing limitations in JVM + this.subsetSize = Ints.saturatedCast(subsetSize); + return this; + } + + public Builder setChildConfig(Object childConfig) { + this.childConfig = checkNotNull(childConfig, "childConfig"); + return this; + } + + public RandomSubsettingLoadBalancerConfig build() { + checkState(subsetSize != 0L, "Subset size must be set before building the config"); + return new RandomSubsettingLoadBalancerConfig( + subsetSize, + checkNotNull(childConfig, "childConfig")); + } + } + } +} diff --git a/util/src/main/java/io/grpc/util/RandomSubsettingLoadBalancerProvider.java b/util/src/main/java/io/grpc/util/RandomSubsettingLoadBalancerProvider.java new file mode 
100644 index 00000000000..edcbf48a201 --- /dev/null +++ b/util/src/main/java/io/grpc/util/RandomSubsettingLoadBalancerProvider.java @@ -0,0 +1,86 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import io.grpc.Internal; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status; +import io.grpc.internal.JsonUtil; +import java.util.Map; + +@Internal +public final class RandomSubsettingLoadBalancerProvider extends LoadBalancerProvider { + private static final String POLICY_NAME = "random_subsetting_experimental"; + + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return new RandomSubsettingLoadBalancer(helper); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return POLICY_NAME; + } + + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { + try { + return parseLoadBalancingPolicyConfigInternal(rawConfig); + } catch (RuntimeException e) { + return ConfigOrError.fromError( + Status.UNAVAILABLE + .withCause(e) + .withDescription("Failed parsing configuration for " + getPolicyName())); + } + } + + private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map rawConfig) { + Long subsetSize = 
JsonUtil.getNumberAsLong(rawConfig, "subsetSize"); + if (subsetSize == null) { + return ConfigOrError.fromError( + Status.UNAVAILABLE.withDescription( + "Subset size missing in " + getPolicyName() + ", LB policy config=" + rawConfig)); + } + + ConfigOrError childConfig = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( + JsonUtil.getListOfObjects(rawConfig, "childPolicy")); + if (childConfig.getError() != null) { + return ConfigOrError.fromError(Status.UNAVAILABLE + .withDescription( + "Failed to parse child in " + getPolicyName() + ", LB policy config=" + rawConfig) + .withCause(childConfig.getError().asRuntimeException())); + } + + return ConfigOrError.fromConfig( + new RandomSubsettingLoadBalancer.RandomSubsettingLoadBalancerConfig.Builder() + .setSubsetSize(subsetSize) + .setChildConfig(childConfig.getConfig()) + .build()); + } +} diff --git a/util/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider b/util/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider index 1fdd69cb00b..d973a6f6728 100644 --- a/util/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider +++ b/util/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider @@ -1,2 +1,3 @@ io.grpc.util.SecretRoundRobinLoadBalancerProvider$Provider io.grpc.util.OutlierDetectionLoadBalancerProvider +io.grpc.util.RandomSubsettingLoadBalancerProvider diff --git a/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java b/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java index f96c85e4f4f..b8431d4f991 100644 --- a/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java +++ b/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java @@ -18,6 +18,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ 
-48,7 +49,6 @@ public class AdvancedTlsX509KeyManagerTest { private static final String SERVER_0_PEM_FILE = "server0.pem"; private static final String CLIENT_0_KEY_FILE = "client.key"; private static final String CLIENT_0_PEM_FILE = "client.pem"; - private static final String ALIAS = "default"; private ScheduledExecutorService executor; @@ -79,22 +79,62 @@ public void setUp() throws Exception { public void updateTrustCredentials_replacesIssuers() throws Exception { // Overall happy path checking of public API. AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0); - assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS)); - assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS)); + String alias1 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(AdvancedTlsX509KeyManager.ALIAS_PREFIX + "1", alias1); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias1)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias1)); serverKeyManager.updateIdentityCredentials(clientCert0File, clientKey0File); - assertEquals(clientKey0, serverKeyManager.getPrivateKey(ALIAS)); - assertArrayEquals(clientCert0, serverKeyManager.getCertificateChain(ALIAS)); - - serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File,1, + String alias2 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(AdvancedTlsX509KeyManager.ALIAS_PREFIX + "2", alias2); + assertEquals(clientKey0, serverKeyManager.getPrivateKey(alias2)); + assertArrayEquals(clientCert0, serverKeyManager.getCertificateChain(alias2)); + // Previous alias still resolves — retained to allow in-progress handshakes to complete. 
+ assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias1)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias1)); + + serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File, 1, TimeUnit.MINUTES, executor); - assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS)); - assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS)); + String alias3 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias3)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias3)); + // alias1 is now two rotations back — no longer retained. + assertNull(serverKeyManager.getPrivateKey(alias1)); serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0); - assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS)); - assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS)); + String alias4 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias4)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias4)); + } + + @Test + public void allAliasMethods_returnNullBeforeCredentialsLoaded() { + AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); + + assertNull(keyManager.chooseClientAlias(null, null, null)); + assertNull(keyManager.chooseServerAlias(null, null, null)); + assertNull(keyManager.chooseEngineClientAlias(null, null, null)); + assertNull(keyManager.chooseEngineServerAlias(null, null, null)); + assertNull(keyManager.getClientAliases(null, null)); + assertNull(keyManager.getServerAliases(null, null)); + assertNull(keyManager.getPrivateKey("key-1")); + assertNull(keyManager.getCertificateChain("key-1")); + } + + @Test + public void allAliasMethods_agreeAfterCredentialLoad() throws Exception { + AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); + 
keyManager.updateIdentityCredentials(serverCert0, serverKey0); + + String expectedAlias = AdvancedTlsX509KeyManager.ALIAS_PREFIX + "1"; + assertEquals(expectedAlias, keyManager.chooseClientAlias(null, null, null)); + assertEquals(expectedAlias, keyManager.chooseServerAlias(null, null, null)); + assertEquals(expectedAlias, keyManager.chooseEngineClientAlias(null, null, null)); + assertEquals(expectedAlias, keyManager.chooseEngineServerAlias(null, null, null)); + assertArrayEquals(new String[]{expectedAlias}, keyManager.getClientAliases(null, null)); + assertArrayEquals(new String[]{expectedAlias}, keyManager.getServerAliases(null, null)); } @Test diff --git a/util/src/test/java/io/grpc/util/AdvancedTlsX509TrustManagerTest.java b/util/src/test/java/io/grpc/util/AdvancedTlsX509TrustManagerTest.java index 36ef75abeaa..b9803b03570 100644 --- a/util/src/test/java/io/grpc/util/AdvancedTlsX509TrustManagerTest.java +++ b/util/src/test/java/io/grpc/util/AdvancedTlsX509TrustManagerTest.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.when; import com.google.common.collect.Iterables; +import com.google.common.io.Files; import io.grpc.internal.FakeClock; import io.grpc.internal.testing.TestUtils; import io.grpc.testing.TlsTesting; @@ -44,6 +45,7 @@ import java.util.logging.LogRecord; import java.util.logging.Logger; import javax.net.ssl.SSLSocket; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -51,25 +53,33 @@ /** Unit tests for {@link AdvancedTlsX509TrustManager}. 
*/ @RunWith(JUnit4.class) +@IgnoreJRERequirement public class AdvancedTlsX509TrustManagerTest { private static final String CA_PEM_FILE = "ca.pem"; private static final String SERVER_0_PEM_FILE = "server0.pem"; + private static final String SERVER_1_PEM_FILE = "server1.pem"; private File caCertFile; private File serverCert0File; + private File serverCert1File; private X509Certificate[] caCert; private X509Certificate[] serverCert0; + private X509Certificate[] serverCert1; + private FakeClock fakeClock; private ScheduledExecutorService executor; @Before public void setUp() throws IOException, GeneralSecurityException { - executor = new FakeClock().getScheduledExecutorService(); + fakeClock = new FakeClock(); + executor = fakeClock.getScheduledExecutorService(); caCertFile = TestUtils.loadCert(CA_PEM_FILE); caCert = CertificateUtils.getX509Certificates(TlsTesting.loadCert(CA_PEM_FILE)); serverCert0File = TestUtils.loadCert(SERVER_0_PEM_FILE); serverCert0 = CertificateUtils.getX509Certificates(TlsTesting.loadCert(SERVER_0_PEM_FILE)); + serverCert1File = TestUtils.loadCert(SERVER_1_PEM_FILE); + serverCert1 = CertificateUtils.getX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); } @Test @@ -132,6 +142,17 @@ record -> record.getMessage().contains("Default value of ")); } } + @Test + public void missingFile_throwsFileNotFoundException() throws Exception { + AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build(); + File nonExistentFile = new File("missing_cert.pem"); + Exception thrown = + assertThrows(Exception.class, () -> trustManager.updateTrustCredentials(nonExistentFile)); + assertNotNull(thrown); + assertEquals(thrown.getMessage(), + "Certificate file not found or not readable: " + nonExistentFile.getAbsolutePath()); + } + @Test public void clientTrustedWithSocketTest() throws Exception { AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder() @@ -145,6 +166,39 @@ public void 
clientTrustedWithSocketTest() throws Exception { assertEquals("No handshake session", ce.getMessage()); } + @Test + public void updateTrustCredentials_rotate() throws GeneralSecurityException, IOException { + AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build(); + trustManager.updateTrustCredentials(serverCert0File); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + + trustManager.updateTrustCredentials(serverCert0File, 1, TimeUnit.MINUTES, + executor); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + + fakeClock.forwardTime(1, TimeUnit.MINUTES); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + + serverCert0File.setLastModified(serverCert0File.lastModified() - 2000); + + fakeClock.forwardTime(1, TimeUnit.MINUTES); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + + long beforeModify = serverCert0File.lastModified(); + Files.copy(serverCert1File, serverCert0File); + serverCert0File.setLastModified(beforeModify); + + // although file content changed, file modification time is not changed + fakeClock.forwardTime(1, TimeUnit.MINUTES); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + + serverCert0File.setLastModified(beforeModify + 2000); + + // file modification time changed + fakeClock.forwardTime(1, TimeUnit.MINUTES); + assertArrayEquals(serverCert1, trustManager.getAcceptedIssuers()); + } + private static class TestHandler extends Handler { private final List records = new ArrayList<>(); diff --git a/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java b/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java index f31443ace7b..0467f9526f6 100644 --- a/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java @@ -18,9 +18,10 @@ import static com.google.common.truth.Truth.assertThat; import static 
io.grpc.ConnectivityState.CONNECTING; +import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.util.GracefulSwitchLoadBalancer.BUFFER_PICKER; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.inOrder; @@ -32,6 +33,7 @@ import static org.mockito.Mockito.when; import com.google.common.testing.EqualsTester; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; @@ -52,9 +54,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -67,10 +67,6 @@ public class GracefulSwitchLoadBalancerTest { private static final Object FAKE_CONFIG = new Object(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - private final Map balancers = new HashMap<>(); private final Map helpers = new HashMap<>(); private final Helper mockHelper = mock(Helper.class); @@ -83,453 +79,6 @@ public class GracefulSwitchLoadBalancerTest { new FakeLoadBalancerProvider("lb_policy_3"), }; - // OLD TESTS - - @Test - @Deprecated - public void switchTo_canHandleEmptyAddressListFromNameResolutionForwardedToLatestPolicy() { - gracefulSwitchLb.switchTo(lbPolicies[0]); - LoadBalancer lb0 = balancers.get(lbPolicies[0]); - Helper helper0 = helpers.get(lb0); - SubchannelPicker picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(READY, picker); - - assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isFalse(); - 
when(lb0.canHandleEmptyAddressListFromNameResolution()).thenReturn(true); - assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isTrue(); - - gracefulSwitchLb.switchTo(lbPolicies[1]); - LoadBalancer lb1 = balancers.get(lbPolicies[1]); - - assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isFalse(); - - when(lb1.canHandleEmptyAddressListFromNameResolution()).thenReturn(true); - assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isTrue(); - - gracefulSwitchLb.switchTo(lbPolicies[2]); - LoadBalancer lb2 = balancers.get(lbPolicies[2]); - - assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isFalse(); - - when(lb2.canHandleEmptyAddressListFromNameResolution()).thenReturn(true); - assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isTrue(); - } - - @Test - @Deprecated - public void switchTo_handleResolvedAddressesAndNameResolutionErrorForwardedToLatestPolicy() { - gracefulSwitchLb.switchTo(lbPolicies[0]); - LoadBalancer lb0 = balancers.get(lbPolicies[0]); - Helper helper0 = helpers.get(lb0); - SubchannelPicker picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(READY, picker); - - ResolvedAddresses addresses = newFakeAddresses(); - gracefulSwitchLb.handleResolvedAddresses(addresses); - verify(lb0).handleResolvedAddresses(addresses); - gracefulSwitchLb.handleNameResolutionError(Status.DATA_LOSS); - verify(lb0).handleNameResolutionError(Status.DATA_LOSS); - - gracefulSwitchLb.switchTo(lbPolicies[1]); - LoadBalancer lb1 = balancers.get(lbPolicies[1]); - addresses = newFakeAddresses(); - gracefulSwitchLb.handleResolvedAddresses(addresses); - verify(lb0, never()).handleResolvedAddresses(addresses); - verify(lb1).handleResolvedAddresses(addresses); - gracefulSwitchLb.handleNameResolutionError(Status.ALREADY_EXISTS); - verify(lb0, never()).handleNameResolutionError(Status.ALREADY_EXISTS); - 
verify(lb1).handleNameResolutionError(Status.ALREADY_EXISTS); - - gracefulSwitchLb.switchTo(lbPolicies[2]); - verify(lb1).shutdown(); - LoadBalancer lb2 = balancers.get(lbPolicies[2]); - addresses = newFakeAddresses(); - gracefulSwitchLb.handleResolvedAddresses(addresses); - verify(lb0, never()).handleResolvedAddresses(addresses); - verify(lb1, never()).handleResolvedAddresses(addresses); - verify(lb2).handleResolvedAddresses(addresses); - gracefulSwitchLb.handleNameResolutionError(Status.CANCELLED); - verify(lb0, never()).handleNameResolutionError(Status.CANCELLED); - verify(lb1, never()).handleNameResolutionError(Status.CANCELLED); - verify(lb2).handleNameResolutionError(Status.CANCELLED); - - verifyNoMoreInteractions(lb0, lb1, lb2); - } - - @Test - @Deprecated - public void switchTo_acceptResolvedAddressesAndNameResolutionErrorForwardedToLatestPolicy() { - gracefulSwitchLb.switchTo(lbPolicies[0]); - LoadBalancer lb0 = balancers.get(lbPolicies[0]); - Helper helper0 = helpers.get(lb0); - SubchannelPicker picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(READY, picker); - - ResolvedAddresses addresses = newFakeAddresses(); - gracefulSwitchLb.acceptResolvedAddresses(addresses); - verify(lb0).acceptResolvedAddresses(addresses); - gracefulSwitchLb.handleNameResolutionError(Status.DATA_LOSS); - verify(lb0).handleNameResolutionError(Status.DATA_LOSS); - - gracefulSwitchLb.switchTo(lbPolicies[1]); - LoadBalancer lb1 = balancers.get(lbPolicies[1]); - addresses = newFakeAddresses(); - gracefulSwitchLb.acceptResolvedAddresses(addresses); - verify(lb0, never()).acceptResolvedAddresses(addresses); - verify(lb1).acceptResolvedAddresses(addresses); - gracefulSwitchLb.handleNameResolutionError(Status.ALREADY_EXISTS); - verify(lb0, never()).handleNameResolutionError(Status.ALREADY_EXISTS); - verify(lb1).handleNameResolutionError(Status.ALREADY_EXISTS); - - gracefulSwitchLb.switchTo(lbPolicies[2]); - verify(lb1).shutdown(); - LoadBalancer lb2 = 
balancers.get(lbPolicies[2]); - addresses = newFakeAddresses(); - gracefulSwitchLb.acceptResolvedAddresses(addresses); - verify(lb0, never()).acceptResolvedAddresses(addresses); - verify(lb1, never()).acceptResolvedAddresses(addresses); - verify(lb2).acceptResolvedAddresses(addresses); - gracefulSwitchLb.handleNameResolutionError(Status.CANCELLED); - verify(lb0, never()).handleNameResolutionError(Status.CANCELLED); - verify(lb1, never()).handleNameResolutionError(Status.CANCELLED); - verify(lb2).handleNameResolutionError(Status.CANCELLED); - - verifyNoMoreInteractions(lb0, lb1, lb2); - } - - @Test - @Deprecated - public void switchTo_shutdownTriggeredWhenSwitchAndForwardedWhenSwitchLbShutdown() { - gracefulSwitchLb.switchTo(lbPolicies[0]); - LoadBalancer lb0 = balancers.get(lbPolicies[0]); - Helper helper0 = helpers.get(lb0); - SubchannelPicker picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(READY, picker); - - gracefulSwitchLb.switchTo(lbPolicies[1]); - LoadBalancer lb1 = balancers.get(lbPolicies[1]); - verify(lb1, never()).shutdown(); - - gracefulSwitchLb.switchTo(lbPolicies[2]); - verify(lb1).shutdown(); - LoadBalancer lb2 = balancers.get(lbPolicies[2]); - verify(lb0, never()).shutdown(); - helpers.get(lb2).updateBalancingState(READY, mock(SubchannelPicker.class)); - verify(lb0).shutdown(); - - gracefulSwitchLb.switchTo(lbPolicies[3]); - LoadBalancer lb3 = balancers.get(lbPolicies[3]); - verify(lb2, never()).shutdown(); - verify(lb3, never()).shutdown(); - - gracefulSwitchLb.shutdown(); - verify(lb2).shutdown(); - verify(lb3).shutdown(); - - verifyNoMoreInteractions(lb0, lb1, lb2, lb3); - } - - @Test - @Deprecated - public void switchTo_requestConnectionForwardedToLatestPolicies() { - gracefulSwitchLb.switchTo(lbPolicies[0]); - LoadBalancer lb0 = balancers.get(lbPolicies[0]); - Helper helper0 = helpers.get(lb0); - SubchannelPicker picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(READY, picker); - - 
gracefulSwitchLb.requestConnection(); - verify(lb0).requestConnection(); - - gracefulSwitchLb.switchTo(lbPolicies[1]); - LoadBalancer lb1 = balancers.get(lbPolicies[1]); - gracefulSwitchLb.requestConnection(); - verify(lb1).requestConnection(); - - gracefulSwitchLb.switchTo(lbPolicies[2]); - verify(lb1).shutdown(); - LoadBalancer lb2 = balancers.get(lbPolicies[2]); - gracefulSwitchLb.requestConnection(); - verify(lb2).requestConnection(); - - // lb2 reports READY - helpers.get(lb2).updateBalancingState(READY, mock(SubchannelPicker.class)); - verify(lb0).shutdown(); - - gracefulSwitchLb.requestConnection(); - verify(lb2, times(2)).requestConnection(); - - gracefulSwitchLb.switchTo(lbPolicies[3]); - LoadBalancer lb3 = balancers.get(lbPolicies[3]); - gracefulSwitchLb.requestConnection(); - verify(lb3).requestConnection(); - - verifyNoMoreInteractions(lb0, lb1, lb2, lb3); - } - - @Test - @Deprecated - public void switchTo_createSubchannelForwarded() { - gracefulSwitchLb.switchTo(lbPolicies[0]); - LoadBalancer lb0 = balancers.get(lbPolicies[0]); - Helper helper0 = helpers.get(lb0); - SubchannelPicker picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(READY, picker); - - CreateSubchannelArgs createSubchannelArgs = newFakeCreateSubchannelArgs(); - helper0.createSubchannel(createSubchannelArgs); - verify(mockHelper).createSubchannel(createSubchannelArgs); - - gracefulSwitchLb.switchTo(lbPolicies[1]); - LoadBalancer lb1 = balancers.get(lbPolicies[1]); - Helper helper1 = helpers.get(lb1); - createSubchannelArgs = newFakeCreateSubchannelArgs(); - helper1.createSubchannel(createSubchannelArgs); - verify(mockHelper).createSubchannel(createSubchannelArgs); - - createSubchannelArgs = newFakeCreateSubchannelArgs(); - helper0.createSubchannel(createSubchannelArgs); - verify(mockHelper).createSubchannel(createSubchannelArgs); - - verifyNoMoreInteractions(lb0, lb1); - } - - @Test - @Deprecated - public void switchTo_updateBalancingStateIsGraceful() { - 
gracefulSwitchLb.switchTo(lbPolicies[0]); - LoadBalancer lb0 = balancers.get(lbPolicies[0]); - Helper helper0 = helpers.get(lb0); - SubchannelPicker picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(READY, picker); - verify(mockHelper).updateBalancingState(READY, picker); - - gracefulSwitchLb.switchTo(lbPolicies[1]); - LoadBalancer lb1 = balancers.get(lbPolicies[1]); - Helper helper1 = helpers.get(lb1); - picker = mock(SubchannelPicker.class); - helper1.updateBalancingState(CONNECTING, picker); - verify(mockHelper, never()).updateBalancingState(CONNECTING, picker); - - gracefulSwitchLb.switchTo(lbPolicies[2]); - verify(lb1).shutdown(); - LoadBalancer lb2 = balancers.get(lbPolicies[2]); - Helper helper2 = helpers.get(lb2); - picker = mock(SubchannelPicker.class); - helper2.updateBalancingState(CONNECTING, picker); - verify(mockHelper, never()).updateBalancingState(CONNECTING, picker); - - // lb2 reports READY - SubchannelPicker picker2 = mock(SubchannelPicker.class); - helper2.updateBalancingState(READY, picker2); - verify(lb0).shutdown(); - verify(mockHelper).updateBalancingState(READY, picker2); - - gracefulSwitchLb.switchTo(lbPolicies[3]); - LoadBalancer lb3 = balancers.get(lbPolicies[3]); - Helper helper3 = helpers.get(lb3); - SubchannelPicker picker3 = mock(SubchannelPicker.class); - helper3.updateBalancingState(CONNECTING, picker3); - verify(mockHelper, never()).updateBalancingState(CONNECTING, picker3); - - // lb2 out of READY - picker2 = mock(SubchannelPicker.class); - helper2.updateBalancingState(CONNECTING, picker2); - verify(mockHelper, never()).updateBalancingState(CONNECTING, picker2); - verify(mockHelper).updateBalancingState(CONNECTING, picker3); - verify(lb2).shutdown(); - - picker3 = mock(SubchannelPicker.class); - helper3.updateBalancingState(CONNECTING, picker3); - verify(mockHelper).updateBalancingState(CONNECTING, picker3); - - verifyNoMoreInteractions(lb0, lb1, lb2, lb3); - } - - @Test - @Deprecated - public void 
switchTo_switchWhileOldPolicyIsNotReady() { - gracefulSwitchLb.switchTo(lbPolicies[0]); - LoadBalancer lb0 = balancers.get(lbPolicies[0]); - Helper helper0 = helpers.get(lb0); - SubchannelPicker picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(READY, picker); - picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(CONNECTING, picker); - - verify(lb0, never()).shutdown(); - gracefulSwitchLb.switchTo(lbPolicies[1]); - verify(lb0).shutdown(); - LoadBalancer lb1 = balancers.get(lbPolicies[1]); - - Helper helper1 = helpers.get(lb1); - picker = mock(SubchannelPicker.class); - helper1.updateBalancingState(CONNECTING, picker); - verify(mockHelper).updateBalancingState(CONNECTING, picker); - - verify(lb1, never()).shutdown(); - gracefulSwitchLb.switchTo(lbPolicies[2]); - verify(lb1).shutdown(); - LoadBalancer lb2 = balancers.get(lbPolicies[2]); - - verifyNoMoreInteractions(lb0, lb1, lb2); - } - - @Test - @Deprecated - public void switchTo_switchWhileOldPolicyGoesFromReadyToNotReady() { - gracefulSwitchLb.switchTo(lbPolicies[0]); - LoadBalancer lb0 = balancers.get(lbPolicies[0]); - Helper helper0 = helpers.get(lb0); - SubchannelPicker picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(READY, picker); - - gracefulSwitchLb.switchTo(lbPolicies[1]); - verify(lb0, never()).shutdown(); - - LoadBalancer lb1 = balancers.get(lbPolicies[1]); - Helper helper1 = helpers.get(lb1); - SubchannelPicker picker1 = mock(SubchannelPicker.class); - helper1.updateBalancingState(CONNECTING, picker1); - verify(mockHelper, never()).updateBalancingState(CONNECTING, picker1); - - picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(CONNECTING, picker); - verify(lb0).shutdown(); - verify(mockHelper, never()).updateBalancingState(CONNECTING, picker); - verify(mockHelper).updateBalancingState(CONNECTING, picker1); - - picker1 = mock(SubchannelPicker.class); - helper1.updateBalancingState(READY, picker1); - 
verify(mockHelper).updateBalancingState(READY, picker1); - - verifyNoMoreInteractions(lb0, lb1); - } - - @Test - @Deprecated - public void switchTo_switchWhileOldPolicyGoesFromReadyToNotReadyWhileNewPolicyStillIdle() { - gracefulSwitchLb.switchTo(lbPolicies[0]); - LoadBalancer lb0 = balancers.get(lbPolicies[0]); - InOrder inOrder = inOrder(lb0, mockHelper); - Helper helper0 = helpers.get(lb0); - SubchannelPicker picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(READY, picker); - - gracefulSwitchLb.switchTo(lbPolicies[1]); - verify(lb0, never()).shutdown(); - - LoadBalancer lb1 = balancers.get(lbPolicies[1]); - Helper helper1 = helpers.get(lb1); - - picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(CONNECTING, picker); - - verify(mockHelper, never()).updateBalancingState(CONNECTING, picker); - inOrder.verify(mockHelper).updateBalancingState(CONNECTING, BUFFER_PICKER); - inOrder.verify(lb0).shutdown(); // shutdown after update - - picker = mock(SubchannelPicker.class); - helper1.updateBalancingState(CONNECTING, picker); - inOrder.verify(mockHelper).updateBalancingState(CONNECTING, picker); - - inOrder.verifyNoMoreInteractions(); - verifyNoMoreInteractions(lb1); - } - - @Test - @Deprecated - public void switchTo_newPolicyNameTheSameAsPendingPolicy_shouldHaveNoEffect() { - gracefulSwitchLb.switchTo(lbPolicies[0]); - LoadBalancer lb0 = balancers.get(lbPolicies[0]); - Helper helper0 = helpers.get(lb0); - SubchannelPicker picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(READY, picker); - - gracefulSwitchLb.switchTo(lbPolicies[1]); - LoadBalancer lb1 = balancers.get(lbPolicies[1]); - - gracefulSwitchLb.switchTo(lbPolicies[1]); - assertThat(balancers.get(lbPolicies[1])).isSameInstanceAs(lb1); - - verifyNoMoreInteractions(lb0, lb1); - } - - @Test - @Deprecated - public void switchTo_newPolicyNameTheSameAsCurrentPolicy_shouldShutdownPendingLb() { - gracefulSwitchLb.switchTo(lbPolicies[0]); - LoadBalancer lb0 = 
balancers.get(lbPolicies[0]); - - gracefulSwitchLb.switchTo(lbPolicies[0]); - assertThat(balancers.get(lbPolicies[0])).isSameInstanceAs(lb0); - - Helper helper0 = helpers.get(lb0); - SubchannelPicker picker = mock(SubchannelPicker.class); - helper0.updateBalancingState(READY, picker); - - gracefulSwitchLb.switchTo(lbPolicies[1]); - LoadBalancer lb1 = balancers.get(lbPolicies[1]); - - gracefulSwitchLb.switchTo(lbPolicies[0]); - verify(lb1).shutdown(); - assertThat(balancers.get(lbPolicies[0])).isSameInstanceAs(lb0); - - verifyNoMoreInteractions(lb0, lb1); - } - - - @Test - @Deprecated - public void switchTo_newLbFactoryEqualToOldOneShouldHaveNoEffect() { - final List balancers = new ArrayList<>(); - - final class LoadBalancerFactoryWithId extends LoadBalancer.Factory { - final int id; - - LoadBalancerFactoryWithId(int id) { - this.id = id; - } - - @Override - public LoadBalancer newLoadBalancer(Helper helper) { - LoadBalancer balancer = mock(LoadBalancer.class); - balancers.add(balancer); - return balancer; - } - - @Override - public boolean equals(Object o) { - if (!(o instanceof LoadBalancerFactoryWithId)) { - return false; - } - LoadBalancerFactoryWithId that = (LoadBalancerFactoryWithId) o; - return id == that.id; - } - - @Override - public int hashCode() { - return id; - } - } - - gracefulSwitchLb.switchTo(new LoadBalancerFactoryWithId(0)); - assertThat(balancers).hasSize(1); - LoadBalancer lb0 = balancers.get(0); - - gracefulSwitchLb.switchTo(new LoadBalancerFactoryWithId(0)); - assertThat(balancers).hasSize(1); - - gracefulSwitchLb.switchTo(new LoadBalancerFactoryWithId(1)); - assertThat(balancers).hasSize(2); - LoadBalancer lb1 = balancers.get(1); - verify(lb0).shutdown(); - - verifyNoMoreInteractions(lb0, lb1); - } - - // END OF OLD TESTS - @Test public void transientFailureOnInitialResolutionError() { gracefulSwitchLb.handleNameResolutionError(Status.DATA_LOSS); @@ -548,11 +97,12 @@ public void handleSubchannelState_shouldThrow() { .build())); Subchannel 
subchannel = mock(Subchannel.class); ConnectivityStateInfo connectivityStateInfo = ConnectivityStateInfo.forNonError(READY); - thrown.expect(UnsupportedOperationException.class); - gracefulSwitchLb.handleSubchannelState(subchannel, connectivityStateInfo); + assertThrows(UnsupportedOperationException.class, + () -> gracefulSwitchLb.handleSubchannelState(subchannel, connectivityStateInfo)); } @Test + @Deprecated public void canHandleEmptyAddressListFromNameResolutionForwardedToLatestPolicy() { assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) @@ -587,6 +137,7 @@ public void canHandleEmptyAddressListFromNameResolutionForwardedToLatestPolicy() assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isTrue(); } + @Deprecated @Test public void handleResolvedAddressesAndNameResolutionErrorForwardedToLatestPolicy() { ResolvedAddresses addresses = newFakeAddresses(); @@ -810,7 +361,21 @@ public void createSubchannelForwarded() { } @Test - public void updateBalancingStateIsGraceful() { + public void updateBalancingStateIsGraceful_Ready() { + updateBalancingStateIsGraceful(READY); + } + + @Test + public void updateBalancingStateIsGraceful_TransientFailure() { + updateBalancingStateIsGraceful(TRANSIENT_FAILURE); + } + + @Test + public void updateBalancingStateIsGraceful_Idle() { + updateBalancingStateIsGraceful(IDLE); + } + + public void updateBalancingStateIsGraceful(ConnectivityState swapsOnState) { assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) .build())); @@ -839,11 +404,11 @@ public void updateBalancingStateIsGraceful() { helper2.updateBalancingState(CONNECTING, picker); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker); - // lb2 reports READY + // lb2 reports swapsOnState SubchannelPicker picker2 = mock(SubchannelPicker.class); - 
helper2.updateBalancingState(READY, picker2); + helper2.updateBalancingState(swapsOnState, picker2); verify(lb0).shutdown(); - verify(mockHelper).updateBalancingState(READY, picker2); + verify(mockHelper).updateBalancingState(swapsOnState, picker2); assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() .setLoadBalancingPolicyConfig(createConfig(lbPolicies[3], new Object())) @@ -854,7 +419,7 @@ public void updateBalancingStateIsGraceful() { helper3.updateBalancingState(CONNECTING, picker3); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker3); - // lb2 out of READY + // lb2 out of swapsOnState picker2 = mock(SubchannelPicker.class); helper2.updateBalancingState(CONNECTING, picker2); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker2); @@ -952,7 +517,11 @@ public void switchWhileOldPolicyGoesFromReadyToNotReadyWhileNewPolicyStillIdle() helper0.updateBalancingState(CONNECTING, picker); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker); - inOrder.verify(mockHelper).updateBalancingState(CONNECTING, BUFFER_PICKER); + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + assertThat(pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)).hasResult()) + .isFalse(); + inOrder.verify(lb0).shutdown(); // shutdown after update picker = mock(SubchannelPicker.class); diff --git a/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java b/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java index 6bfd6d7a659..14dc8518756 100644 --- a/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java @@ -52,7 +52,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Collectors; import org.junit.Before; import org.junit.Rule; import 
org.junit.Test; @@ -153,8 +152,7 @@ public void pickAfterResolvedUpdatedHosts() { LoadBalancer.Subchannel removedSubchannel = getSubchannel(removedEag); LoadBalancer.Subchannel oldSubchannel = getSubchannel(oldEag1); LoadBalancer.SubchannelStateListener removedListener = - testHelperInst.getSubchannelStateListeners() - .get(testHelperInst.getRealForMockSubChannel(removedSubchannel)); + testHelperInst.getSubchannelStateListener(removedSubchannel); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -168,8 +166,6 @@ public void pickAfterResolvedUpdatedHosts() { verify(removedSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).requestConnection(); - assertThat(getChildEags(loadBalancer)).containsExactly(removedEag, oldEag1); - // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); @@ -186,10 +182,10 @@ public void pickAfterResolvedUpdatedHosts() { removedListener.onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(getChildEags(loadBalancer)).containsExactly(oldEag2, newEag); - verify(mockHelper, times(3)).createSubchannel(any(LoadBalancer.CreateSubchannelArgs.class)); inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); + picker = pickerCaptor.getValue(); + assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel); AbstractTestHelper.verifyNoMoreMeaningfulInteractions(mockHelper); } @@ -268,6 +264,42 @@ public void testEndpoint_equals() { .testEquals(); } + @Test + public void offsetIterable_positive() { + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(1, 2, 3, 4), 9)) + .containsExactly(2, 3, 4, 1) + .inOrder(); + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(1, 2, 3, 4, 5), 9)) + .containsExactly(5, 1, 2, 3, 4) + .inOrder(); + 
assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(1, 2, 3), 3)) + .containsExactly(1, 2, 3) + .inOrder(); + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(1, 2, 3), 0)) + .containsExactly(1, 2, 3) + .inOrder(); + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(1), 123)) + .containsExactly(1) + .inOrder(); + } + + @Test + public void offsetIterable_negative() { + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(1, 2, 3, 4), -1)) + .containsExactly(4, 1, 2, 3) + .inOrder(); + } + + @Test + public void offsetIterable_empty() { + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(), 1)) + .isEmpty(); + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(), 0)) + .isEmpty(); + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(), -1)) + .isEmpty(); + } + private String addressesOnlyString(EquivalentAddressGroup eag) { if (eag == null) { return null; @@ -328,12 +360,6 @@ private LoadBalancer.Subchannel getSubchannel(EquivalentAddressGroup eag) { return null; } - private static List getChildEags(MultiChildLoadBalancer loadBalancer) { - return loadBalancer.getChildLbStates().stream() - .map(ChildLbState::getEag) - .collect(Collectors.toList()); - } - private void deliverSubchannelState(LoadBalancer.Subchannel subchannel, ConnectivityStateInfo newState) { testHelperInst.deliverSubchannelState(subchannel, newState); @@ -348,13 +374,16 @@ protected TestLb(Helper mockHelper) { protected void updateOverallBalancingState() { ConnectivityState overallState = null; final Map childPickers = new HashMap<>(); + final Map childConnStates = new HashMap<>(); for (ChildLbState childLbState : getChildLbStates()) { childPickers.put(childLbState.getKey(), childLbState.getCurrentPicker()); + childConnStates.put(childLbState.getKey(), childLbState.getCurrentState()); overallState = aggregateState(overallState, childLbState.getCurrentState()); } if (overallState != null) { - 
getHelper().updateBalancingState(overallState, new TestSubchannelPicker(childPickers)); + getHelper().updateBalancingState( + overallState, new TestSubchannelPicker(childPickers, childConnStates)); currentConnectivityState = overallState; } @@ -364,18 +393,17 @@ private class TestSubchannelPicker extends SubchannelPicker { Map childPickerMap; Map childStates = new HashMap<>(); - TestSubchannelPicker(Map childPickers) { - childPickerMap = childPickers; - for (Object key : childPickerMap.keySet()) { - childStates.put(key, getChildLbState(key).getCurrentState()); - } + TestSubchannelPicker( + Map childPickers, Map childStates) { + this.childPickerMap = childPickers; + this.childStates = childStates; } List getReadySubchannels() { List readySubchannels = new ArrayList<>(); for ( Map.Entry cur : childStates.entrySet()) { if (cur.getValue() == READY) { - Subchannel s = subchannels.get(Arrays.asList(getChildLbState(cur.getKey()).getEag())); + Subchannel s = childPickerMap.get(cur.getKey()).pickSubchannel(null).getSubchannel(); readySubchannels.add(s); } } diff --git a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java index 1b0139affef..39f5b5fb7d6 100644 --- a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java @@ -54,7 +54,6 @@ import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; import io.grpc.internal.FakeClock.ScheduledTask; -import io.grpc.internal.PickFirstLoadBalancerProvider; import io.grpc.internal.TestUtils.StandardLoadBalancerProvider; import io.grpc.util.OutlierDetectionLoadBalancer.EndpointTracker; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; @@ -227,7 +226,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { when(mockStreamTracerFactory.newClientStreamTracer(any(), 
any())).thenReturn(mockStreamTracer); - loadBalancer = new OutlierDetectionLoadBalancer(mockHelper, fakeClock.getTimeProvider()); + loadBalancer = new OutlierDetectionLoadBalancer(mockHelper, fakeClock.getTicker()); } @Test @@ -280,7 +279,7 @@ public void acceptResolvedAddresses() { loadBalancer.acceptResolvedAddresses(resolvedAddresses); // Handling of resolved addresses is delegated - verify(mockChildLb).handleResolvedAddresses( + verify(mockChildLb).acceptResolvedAddresses( resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build()); // There is a single pending task to run the outlier detection algorithm @@ -409,7 +408,7 @@ public void delegatePick() throws Exception { // Make sure that we can pick the single READY subchannel. SubchannelPicker picker = pickerCaptor.getAllValues().get(2); PickResult pickResult = picker.pickSubchannel(mock(PickSubchannelArgs.class)); - Subchannel s = ((OutlierDetectionSubchannel) pickResult.getSubchannel()).delegate(); + Subchannel s = pickResult.getSubchannel(); if (s instanceof HealthProducerHelper.HealthProducerSubchannel) { s = ((HealthProducerHelper.HealthProducerSubchannel) s).delegate(); } @@ -568,9 +567,7 @@ public void successRateOneOutlier_configChange() { loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); - // The PickFirstLeafLB has an extra level of indirection because of health - int expectedStateChanges = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 8 : 12; - generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), expectedStateChanges); + generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 8); // Move forward in time to a point where the detection timer has fired. 
forwardTime(config); @@ -604,8 +601,7 @@ public void successRateOneOutlier_unejected() { assertEjectedSubchannels(ImmutableSet.of(ImmutableSet.copyOf(servers.get(0).getAddresses()))); // Now we produce more load, but the subchannel has started working and is no longer an outlier. - int expectedStateChanges = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 8 : 12; - generateLoad(ImmutableMap.of(), expectedStateChanges); + generateLoad(ImmutableMap.of(), 8); // Move forward in time to a point where the detection timer has fired. fakeClock.forwardTime(config.maxEjectionTimeNanos + 1, TimeUnit.NANOSECONDS); diff --git a/util/src/test/java/io/grpc/util/RandomSubsettingLoadBalancerProviderTest.java b/util/src/test/java/io/grpc/util/RandomSubsettingLoadBalancerProviderTest.java new file mode 100644 index 00000000000..18a0766d4b2 --- /dev/null +++ b/util/src/test/java/io/grpc/util/RandomSubsettingLoadBalancerProviderTest.java @@ -0,0 +1,135 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.util; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +import io.grpc.InternalServiceProviders; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status; +import io.grpc.internal.JsonParser; +import io.grpc.util.RandomSubsettingLoadBalancer.RandomSubsettingLoadBalancerConfig; +import java.io.IOException; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class RandomSubsettingLoadBalancerProviderTest { + private final RandomSubsettingLoadBalancerProvider provider = + new RandomSubsettingLoadBalancerProvider(); + + @Test + public void registered() { + for (LoadBalancerProvider current : + InternalServiceProviders.getCandidatesViaServiceLoader( + LoadBalancerProvider.class, getClass().getClassLoader())) { + if (current instanceof RandomSubsettingLoadBalancerProvider) { + return; + } + } + fail("RandomSubsettingLoadBalancerProvider not registered"); + } + + @Test + public void providesLoadBalancer() { + Helper helper = mock(Helper.class); + assertThat(provider.newLoadBalancer(helper)) + .isInstanceOf(RandomSubsettingLoadBalancer.class); + } + + @Test + public void parseConfigRequiresSubsetSize() throws IOException { + String emptyConfig = "{}"; + + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(emptyConfig)); + assertThat(configOrError.getError()).isNotNull(); + assertThat(configOrError.getError().toString()) + .isEqualTo( + Status.UNAVAILABLE + .withDescription( + "Subset size missing in random_subsetting_experimental, LB policy config={}") + .toString()); + } + + @Test + public void parseConfigReturnsErrorWhenChildPolicyMissing() throws IOException { + String missingChildPolicyConfig = "{\"subsetSize\": 3}"; + + ConfigOrError 
configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(missingChildPolicyConfig)); + assertThat(configOrError.getError()).isNotNull(); + + Status error = configOrError.getError(); + assertThat(error.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(error.getDescription()).isEqualTo( + "Failed to parse child in random_subsetting_experimental" + + ", LB policy config={subsetSize=3.0}"); + assertThat(error.getCause().getMessage()).isEqualTo( + "UNAVAILABLE: No child LB config specified"); + } + + @Test + public void parseConfigReturnsErrorWhenChildPolicyInvalid() throws IOException { + String invalidChildPolicyConfig = + "{" + + "\"subsetSize\": 3, " + + "\"childPolicy\" : [{\"random_policy\" : {}}]" + + "}"; + + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(invalidChildPolicyConfig)); + assertThat(configOrError.getError()).isNotNull(); + + Status error = configOrError.getError(); + assertThat(error.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(error.getDescription()).isEqualTo( + "Failed to parse child in random_subsetting_experimental, LB policy config=" + + "{subsetSize=3.0, childPolicy=[{random_policy={}}]}"); + assertThat(error.getCause().getMessage()).contains( + "UNAVAILABLE: None of [random_policy] specified by Service Config are available."); + } + + @Test + public void parseValidConfig() throws IOException { + String validConfig = + "{" + + "\"subsetSize\": 3, " + + "\"childPolicy\" : [{\"round_robin\" : {}}]" + + "}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(validConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + + RandomSubsettingLoadBalancerConfig actualConfig = + (RandomSubsettingLoadBalancerConfig) configOrError.getConfig(); + assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider( + actualConfig.childConfig).getPolicyName()).isEqualTo("round_robin"); + 
assertThat(actualConfig.subsetSize).isEqualTo(3); + } + + @SuppressWarnings("unchecked") + private static Map parseJsonObject(String json) throws IOException { + return (Map) JsonParser.parse(json); + } +} diff --git a/util/src/test/java/io/grpc/util/RandomSubsettingLoadBalancerTest.java b/util/src/test/java/io/grpc/util/RandomSubsettingLoadBalancerTest.java new file mode 100644 index 00000000000..2c43e8f4c3a --- /dev/null +++ b/util/src/test/java/io/grpc/util/RandomSubsettingLoadBalancerTest.java @@ -0,0 +1,333 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.util; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.LoadBalancerProvider; +import io.grpc.Status; +import io.grpc.internal.TestUtils; +import io.grpc.util.RandomSubsettingLoadBalancer.RandomSubsettingLoadBalancerConfig; +import java.net.SocketAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.mockito.stubbing.Answer; + +@RunWith(JUnit4.class) +public class RandomSubsettingLoadBalancerTest { + @Rule + public final MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Mock + private LoadBalancer.Helper mockHelper; + @Mock + private LoadBalancer mockChildLb; + @Mock + private SocketAddress mockSocketAddress; + + @Captor + private ArgumentCaptor resolvedAddrCaptor; + + private BackendDetails backendDetails; + + private RandomSubsettingLoadBalancer loadBalancer; + + 
private final LoadBalancerProvider mockChildLbProvider = + new TestUtils.StandardLoadBalancerProvider("foo_policy") { + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return mockChildLb; + } + }; + + private final LoadBalancerProvider roundRobinLbProvider = + new TestUtils.StandardLoadBalancerProvider("round_robin") { + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return new RoundRobinLoadBalancer(helper); + } + }; + + private Object newChildConfig(LoadBalancerProvider provider, Object config) { + return GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(provider, config); + } + + private RandomSubsettingLoadBalancerConfig createRandomSubsettingLbConfig( + int subsetSize, LoadBalancerProvider childLbProvider, Object childConfig) { + return new RandomSubsettingLoadBalancer.RandomSubsettingLoadBalancerConfig.Builder() + .setSubsetSize(subsetSize) + .setChildConfig(newChildConfig(childLbProvider, childConfig)) + .build(); + } + + private BackendDetails setupBackends(int backendCount) { + List servers = Lists.newArrayList(); + Map, Subchannel> subchannels = Maps.newLinkedHashMap(); + + for (int i = 0; i < backendCount; i++) { + SocketAddress addr = new FakeSocketAddress("server" + i); + EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(addr); + servers.add(addressGroup); + Subchannel subchannel = mock(Subchannel.class); + subchannels.put(Arrays.asList(addressGroup), subchannel); + } + + return new BackendDetails(servers, subchannels); + } + + @Before + public void setUp() { + int seed = 0; + loadBalancer = new RandomSubsettingLoadBalancer(mockHelper, seed); + + int backendSize = 5; + backendDetails = setupBackends(backendSize); + } + + @Test + public void handleNameResolutionError() { + int subsetSize = 2; + Object childConfig = "someConfig"; + + RandomSubsettingLoadBalancerConfig config = createRandomSubsettingLbConfig( + subsetSize, mockChildLbProvider, childConfig); + + 
loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of(new EquivalentAddressGroup(mockSocketAddress))) + .setLoadBalancingPolicyConfig(config) + .build()); + + loadBalancer.handleNameResolutionError(Status.DEADLINE_EXCEEDED); + verify(mockChildLb).handleNameResolutionError(Status.DEADLINE_EXCEEDED); + } + + @Test + public void shutdown() { + int subsetSize = 2; + Object childConfig = "someConfig"; + + RandomSubsettingLoadBalancerConfig config = createRandomSubsettingLbConfig( + subsetSize, mockChildLbProvider, childConfig); + + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of(new EquivalentAddressGroup(mockSocketAddress))) + .setLoadBalancingPolicyConfig(config) + .build()); + + loadBalancer.shutdown(); + verify(mockChildLb).shutdown(); + } + + @Test + public void acceptResolvedAddresses_mockedChildLbPolicy() { + int subsetSize = 3; + Object childConfig = "someConfig"; + + RandomSubsettingLoadBalancerConfig config = createRandomSubsettingLbConfig( + subsetSize, mockChildLbProvider, childConfig); + + ResolvedAddresses resolvedAddresses = + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.copyOf(backendDetails.servers)) + .setLoadBalancingPolicyConfig(config) + .build(); + + loadBalancer.acceptResolvedAddresses(resolvedAddresses); + + verify(mockChildLb).acceptResolvedAddresses(resolvedAddrCaptor.capture()); + assertThat(resolvedAddrCaptor.getValue().getAddresses().size()).isEqualTo(subsetSize); + assertThat(resolvedAddrCaptor.getValue().getLoadBalancingPolicyConfig()).isEqualTo(childConfig); + } + + @Test + public void acceptResolvedAddresses_roundRobinChildLbPolicy() { + int subsetSize = 3; + Object childConfig = null; + + RandomSubsettingLoadBalancerConfig config = createRandomSubsettingLbConfig( + subsetSize, roundRobinLbProvider, childConfig); + + ResolvedAddresses resolvedAddresses = + ResolvedAddresses.newBuilder() + 
.setAddresses(ImmutableList.copyOf(backendDetails.servers)) + .setLoadBalancingPolicyConfig(config) + .build(); + + loadBalancer.acceptResolvedAddresses(resolvedAddresses); + + int insubset = 0; + for (Subchannel subchannel : backendDetails.subchannels.values()) { + LoadBalancer.SubchannelStateListener ssl = + backendDetails.subchannelStateListeners.get(subchannel); + if (ssl != null) { // it might be null if it's not in the subset. + insubset += 1; + ssl.onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + } + } + + assertThat(insubset).isEqualTo(subsetSize); + } + + // verifies: https://github.com/grpc/proposal/blob/master/A68_graphics/subsetting100-100-5.png + @Test + public void backendsCanBeDistributedEvenly_subsetting100_100_5() { + verifyConnectionsByServer(100, 100, 5, 15); + } + + // verifies https://github.com/grpc/proposal/blob/master/A68_graphics/subsetting100-100-25.png + @Test + public void backendsCanBeDistributedEvenly_subsetting100_100_25() { + verifyConnectionsByServer(100, 100, 25, 40); + } + + // verifies: https://github.com/grpc/proposal/blob/master/A68_graphics/subsetting100-10-5.png + @Test + public void backendsCanBeDistributedEvenly_subsetting100_10_5() { + verifyConnectionsByServer(100, 10, 5, 65); + } + + // verifies: https://github.com/grpc/proposal/blob/master/A68_graphics/subsetting500-10-5.png + @Test + public void backendsCanBeDistributedEvenly_subsetting500_10_5() { + verifyConnectionsByServer(500, 10, 5, 600); + } + + // verifies: https://github.com/grpc/proposal/blob/master/A68_graphics/subsetting2000-10-5.png + @Test + public void backendsCanBeDistributedEvenly_subsetting2000_100_5() { + verifyConnectionsByServer(2000, 10, 5, 1200); + } + + public void verifyConnectionsByServer( + int clientsCount, int serversCount, int subsetSize, int expectedMaxConnections) { + backendDetails = setupBackends(serversCount); + Object childConfig = "someConfig"; + + List configs = Lists.newArrayList(); + for (int i = 0; 
i < clientsCount; i++) { + configs.add(createRandomSubsettingLbConfig(subsetSize, mockChildLbProvider, childConfig)); + } + + Map connectionsByServer = Maps.newLinkedHashMap(); + + for (int i = 0; i < clientsCount; i++) { + RandomSubsettingLoadBalancerConfig config = configs.get(i); + + ResolvedAddresses resolvedAddresses = + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.copyOf(backendDetails.servers)) + .setLoadBalancingPolicyConfig(config) + .build(); + + loadBalancer = new RandomSubsettingLoadBalancer(mockHelper, i); + loadBalancer.acceptResolvedAddresses(resolvedAddresses); + + verify(mockChildLb, atLeastOnce()).acceptResolvedAddresses(resolvedAddrCaptor.capture()); + // Verify ChildLB is only getting subsetSize ResolvedAddresses each time + assertThat(resolvedAddrCaptor.getValue().getAddresses().size()).isEqualTo(config.subsetSize); + + for (EquivalentAddressGroup eag : resolvedAddrCaptor.getValue().getAddresses()) { + for (SocketAddress addr : eag.getAddresses()) { + Integer prev = connectionsByServer.getOrDefault(addr, 0); + connectionsByServer.put(addr, prev + 1); + } + } + } + + int maxConnections = Collections.max(connectionsByServer.values()); + + assertThat(maxConnections).isAtMost(expectedMaxConnections); + } + + private class BackendDetails { + private final List servers; + private final Map, Subchannel> subchannels; + private final Map subchannelStateListeners; + + BackendDetails(List servers, + Map, Subchannel> subchannels) { + this.servers = servers; + this.subchannels = subchannels; + this.subchannelStateListeners = Maps.newLinkedHashMap(); + + when(mockHelper.createSubchannel(any(LoadBalancer.CreateSubchannelArgs.class))).then( + new Answer() { + @Override + public Subchannel answer(InvocationOnMock invocation) throws Throwable { + CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; + final Subchannel subchannel = backendDetails.subchannels.get(args.getAddresses()); + 
when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); + when(subchannel.getAttributes()).thenReturn(args.getAttributes()); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + subchannelStateListeners.put(subchannel, + (SubchannelStateListener) invocation.getArguments()[0]); + return null; + } + }).when(subchannel).start(any(SubchannelStateListener.class)); + return subchannel; + } + }); + } + } + + private static class FakeSocketAddress extends SocketAddress { + final String name; + + FakeSocketAddress(String name) { + this.name = name; + } + + @Override + public String toString() { + return "FakeSocketAddress-" + name; + } + } +} diff --git a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index 743bbbef796..18854ca1bb6 100644 --- a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -22,7 +22,6 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.util.MultiChildLoadBalancer.IS_PETIOLE_POLICY; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; @@ -54,8 +53,9 @@ import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.Status; +import io.grpc.internal.PickFirstLoadBalancerProvider; +import io.grpc.internal.PickFirstLoadBalancerProviderAccessor; import io.grpc.internal.TestUtils; -import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.util.RoundRobinLoadBalancer.ReadyPicker; import java.net.SocketAddress; import java.util.ArrayList; @@ -105,6 +105,7 @@ public class RoundRobinLoadBalancerTest { private ArgumentCaptor createArgsCaptor; private TestHelper testHelperInst = new 
TestHelper(); private Helper mockHelper = mock(Helper.class, delegatesTo(testHelperInst)); + private boolean defaultNewPickFirst = PickFirstLoadBalancerProvider.isEnabledNewPickFirst(); @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown(). private PickSubchannelArgs mockArgs; @@ -127,6 +128,7 @@ private Status acceptAddresses(List eagList, Attributes @After public void tearDown() throws Exception { + PickFirstLoadBalancerProviderAccessor.setEnableNewPickFirst(defaultNewPickFirst); verifyNoMoreInteractions(mockArgs); } @@ -202,16 +204,6 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { verify(removedSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).requestConnection(); - assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2); - for (ChildLbState childLbState : loadBalancer.getChildLbStates()) { - assertThat(childLbState.getResolvedAddresses().getAttributes().get(IS_PETIOLE_POLICY)) - .isTrue(); - } - assertThat(loadBalancer.getChildLbStateEag(removedEag).getCurrentPicker().pickSubchannel(null) - .getSubchannel()).isEqualTo(removedSubchannel); - assertThat(loadBalancer.getChildLbStateEag(oldEag1).getCurrentPicker().pickSubchannel(null) - .getSubchannel()).isEqualTo(oldSubchannel); - // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); @@ -225,12 +217,6 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2); - assertThat(loadBalancer.getChildLbStateEag(newEag).getCurrentPicker() - .pickSubchannel(null).getSubchannel()).isEqualTo(newSubchannel); - assertThat(loadBalancer.getChildLbStateEag(oldEag2).getCurrentPicker() - .pickSubchannel(null).getSubchannel()).isEqualTo(oldSubchannel); - verify(mockHelper, times(6)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper, 
times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); @@ -243,35 +229,35 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(mockHelper); - Status addressesAcceptanceStatus = acceptAddresses(servers, Attributes.EMPTY); + Status addressesAcceptanceStatus = + acceptAddresses(Arrays.asList(servers.get(0)), Attributes.EMPTY); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); + inOrder.verify(mockHelper).createSubchannel(any(CreateSubchannelArgs.class)); // TODO figure out if this method testing the right things - ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); - Subchannel subchannel = subchannels.get(Arrays.asList(childLbState.getEag())); + assertThat(subchannels).hasSize(1); + Subchannel subchannel = subchannels.values().iterator().next(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); - assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class); - assertThat(childLbState.getCurrentState()).isEqualTo(READY); Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"); deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); - assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); - AbstractTestHelper.refreshInvokedAndUpdateBS(inOrder, CONNECTING, mockHelper, pickerCaptor); - assertThat(pickerCaptor.getValue()).isEqualTo(EMPTY_PICKER); + AbstractTestHelper.refreshInvokedAndUpdateBS( + inOrder, TRANSIENT_FAILURE, mockHelper, pickerCaptor); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()).isEqualTo(error); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); 
inOrder.verify(mockHelper).refreshNameResolution(); - assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); + inOrder.verify(mockHelper, never()) + .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); verify(subchannel, atLeastOnce()).requestConnection(); - verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); AbstractTestHelper.verifyNoMoreMeaningfulInteractions(mockHelper); } @@ -282,10 +268,10 @@ public void ignoreShutdownSubchannelStateChange() { assertThat(addressesAcceptanceStatus.isOk()).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); + List savedSubchannels = new ArrayList<>(subchannels.values()); loadBalancer.shutdown(); - for (ChildLbState child : loadBalancer.getChildLbStates()) { - Subchannel sc = child.getCurrentPicker().pickSubchannel(null).getSubchannel(); - verify(child).shutdown(); + for (Subchannel sc : savedSubchannels) { + verify(sc).shutdown(); // When the subchannel is being shut down, a SHUTDOWN connectivity state is delivered // back to the subchannel state listener. deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(SHUTDOWN)); @@ -300,34 +286,27 @@ public void stayTransientFailureUntilReady() { Status addressesAcceptanceStatus = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); + inOrder.verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); - Map childToSubChannelMap = new HashMap<>(); // Simulate state transitions for each subchannel individually. 
- for ( ChildLbState child : loadBalancer.getChildLbStates()) { - Subchannel sc = subchannels.get(Arrays.asList(child.getEag())); - childToSubChannelMap.put(child, sc); + for (Subchannel sc : subchannels.values()) { Status error = Status.UNKNOWN.withDescription("connection broken"); deliverSubchannelState( sc, ConnectivityStateInfo.forTransientFailure(error)); - assertEquals(TRANSIENT_FAILURE, child.getCurrentState()); deliverSubchannelState( sc, ConnectivityStateInfo.forNonError(CONNECTING)); - assertEquals(TRANSIENT_FAILURE, child.getCurrentState()); } inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(ReadyPicker.class)); inOrder.verify(mockHelper, atLeast(0)).refreshNameResolution(); inOrder.verifyNoMoreInteractions(); - ChildLbState child = loadBalancer.getChildLbStates().iterator().next(); - Subchannel subchannel = childToSubChannelMap.get(child); + Subchannel subchannel = subchannels.values().iterator().next(); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(child.getCurrentState()).isEqualTo(READY); inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); - verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper, atLeast(0)).refreshNameResolution(); inOrder.verifyNoMoreInteractions(); } @@ -342,8 +321,7 @@ public void refreshNameResolutionWhenSubchannelConnectionBroken() { inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); // Simulate state transitions for each subchannel individually. 
- for (ChildLbState child : loadBalancer.getChildLbStates()) { - Subchannel sc = subchannels.get(Arrays.asList(child.getEag())); + for (Subchannel sc : subchannels.values()) { verify(sc).requestConnection(); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); Status error = Status.UNKNOWN.withDescription("connection broken"); @@ -483,6 +461,60 @@ public void subchannelStateIsolation() throws Exception { assertThat(pickers.hasNext()).isFalse(); } + @Test + public void subchannelHealthObserved() throws Exception { + // Only the new PF policy observes the new separate listener for health + PickFirstLoadBalancerProviderAccessor.setEnableNewPickFirst(true); + // PickFirst does most of this work. If the test fails, check IS_PETIOLE_POLICY + Map healthListeners = new HashMap<>(); + loadBalancer = new RoundRobinLoadBalancer(new ForwardingLoadBalancerHelper() { + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + Subchannel subchannel = super.createSubchannel(args.toBuilder() + .setAttributes(args.getAttributes().toBuilder() + .set(LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY, true) + .build()) + .build()); + healthListeners.put( + subchannel, args.getOption(LoadBalancer.HEALTH_CONSUMER_LISTENER_ARG_KEY)); + return subchannel; + } + + @Override + protected Helper delegate() { + return mockHelper; + } + }); + + InOrder inOrder = inOrder(mockHelper); + Status addressesAcceptanceStatus = acceptAddresses(servers, Attributes.EMPTY); + assertThat(addressesAcceptanceStatus.isOk()).isTrue(); + Subchannel subchannel0 = subchannels.get(Arrays.asList(servers.get(0))); + Subchannel subchannel1 = subchannels.get(Arrays.asList(servers.get(1))); + Subchannel subchannel2 = subchannels.get(Arrays.asList(servers.get(2))); + + // Subchannels go READY, but the LB waits for health + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + } + 
inOrder.verify(mockHelper, times(0)) + .updateBalancingState(eq(READY), any(SubchannelPicker.class)); + + // Health results lets subchannels go READY + healthListeners.get(subchannel0).onSubchannelState( + ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription("oh no"))); + healthListeners.get(subchannel1).onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + healthListeners.get(subchannel2).onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); + SubchannelPicker picker = pickerCaptor.getValue(); + List picks = Arrays.asList( + picker.pickSubchannel(mockArgs).getSubchannel(), + picker.pickSubchannel(mockArgs).getSubchannel(), + picker.pickSubchannel(mockArgs).getSubchannel(), + picker.pickSubchannel(mockArgs).getSubchannel()); + assertThat(picks).containsExactly(subchannel1, subchannel2, subchannel1, subchannel2); + } + @Test public void readyPicker_emptyList() { // ready picker list must be non-empty diff --git a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java index bdeff9d17c5..837dc68c057 100644 --- a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java +++ b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java @@ -16,6 +16,7 @@ package io.grpc.util; +import static com.google.common.base.Preconditions.checkNotNull; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeast; @@ -23,7 +24,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import com.google.common.collect.Maps; import io.grpc.Attributes; import io.grpc.Channel; import io.grpc.ChannelLogger; @@ -55,6 +55,7 @@ * To use it replace
* \@mock Helper mockHelper
* with
+ * *

Helper mockHelper = mock(Helper.class, delegatesTo(new TestHelper()));

*
* TestHelper will need to define accessors for the maps that information is store within as @@ -62,10 +63,8 @@ */ public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper { - private final Map mockToRealSubChannelMap = new HashMap<>(); + private final Map mockToRealSubChannelMap = new HashMap<>(); protected final Map realToMockSubChannelMap = new HashMap<>(); - private final Map subchannelStateListeners = - Maps.newLinkedHashMap(); private final FakeClock fakeClock; private final SynchronizationContext syncContext; @@ -86,22 +85,14 @@ public AbstractTestHelper(FakeClock fakeClock, SynchronizationContext syncContex this.syncContext = syncContext; } - public Map getMockToRealSubChannelMap() { - return mockToRealSubChannelMap; - } - - public Subchannel getRealForMockSubChannel(Subchannel mock) { - Subchannel realSc = getMockToRealSubChannelMap().get(mock); + private TestSubchannel getRealForMockSubChannel(Subchannel mock) { + TestSubchannel realSc = mockToRealSubChannelMap.get(mock); if (realSc == null) { - realSc = mock; + realSc = (TestSubchannel) mock; } return realSc; } - public Map getSubchannelStateListeners() { - return subchannelStateListeners; - } - public static final FakeClock.TaskFilter NOT_START_NEXT_CONNECTION = new FakeClock.TaskFilter() { @Override @@ -115,15 +106,15 @@ public static int getNumFilteredPendingTasks(FakeClock fakeClock) { } public void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { - Subchannel realSc = getMockToRealSubChannelMap().get(subchannel); - if (realSc == null) { - realSc = subchannel; - } - SubchannelStateListener listener = getSubchannelStateListeners().get(realSc); + getSubchannelStateListener(subchannel).onSubchannelState(newState); + } + + public SubchannelStateListener getSubchannelStateListener(Subchannel subchannel) { + SubchannelStateListener listener = getRealForMockSubChannel(subchannel).listener; if (listener == null) { - throw new 
IllegalArgumentException("subchannel does not have a matching listener"); + throw new IllegalArgumentException("subchannel has not been started"); } - listener.onSubchannelState(newState); + return listener; } @Override @@ -143,7 +134,7 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { TestSubchannel delegate = createRealSubchannel(args); subchannel = mock(Subchannel.class, delegatesTo(delegate)); getSubchannelMap().put(args.getAddresses(), subchannel); - getMockToRealSubChannelMap().put(subchannel, delegate); + mockToRealSubChannelMap.put(subchannel, delegate); realToMockSubChannelMap.put(delegate, subchannel); } @@ -160,7 +151,7 @@ public void refreshNameResolution() { } public void setChannel(Subchannel subchannel, Channel channel) { - ((TestSubchannel)subchannel).channel = channel; + getRealForMockSubChannel(subchannel).channel = channel; } @Override @@ -207,6 +198,7 @@ public static void verifyNoMoreMeaningfulInteractions(Helper helper, InOrder inO protected class TestSubchannel extends ForwardingSubchannel { CreateSubchannelArgs args; + SubchannelStateListener listener; Channel channel; public TestSubchannel(CreateSubchannelArgs args) { @@ -249,12 +241,11 @@ public void updateAddresses(List addrs) { @Override public void start(SubchannelStateListener listener) { - getSubchannelStateListeners().put(this, listener); + this.listener = checkNotNull(listener, "listener"); } @Override public void shutdown() { - getSubchannelStateListeners().remove(this); for (EquivalentAddressGroup eag : getAllAddresses()) { getSubchannelMap().remove(Collections.singletonList(eag)); } diff --git a/xds/BUILD.bazel b/xds/BUILD.bazel index 05753a3a320..9a650485c6c 100644 --- a/xds/BUILD.bazel +++ b/xds/BUILD.bazel @@ -1,5 +1,9 @@ +load("@bazel_jar_jar//:jar_jar.bzl", "jar_jar") +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") +load("@rules_java//java:defs.bzl", 
"java_binary", "java_library", "java_test") load("@rules_jvm_external//:defs.bzl", "artifact") -load("//:java_grpc_library.bzl", "java_grpc_library") +load("//:java_grpc_library.bzl", "INTERNAL_java_grpc_library_for_xds", "java_grpc_library", "java_rpc_toolchain") # Mirrors the dependencies included in the artifact on Maven Central for usage # with maven_install's override_targets. Should only be used as a dep for @@ -13,30 +17,15 @@ java_library( ], ) +# Ordinary deps for :xds java_library( - name = "xds", - srcs = glob( - [ - "src/main/java/**/*.java", - "third_party/zero-allocation-hashing/main/java/**/*.java", - ], - exclude = ["src/main/java/io/grpc/xds/orca/**"], - ), - resources = glob([ - "src/main/resources/**", - ]), - visibility = ["//visibility:public"], - deps = [ - ":envoy_service_discovery_v2_java_grpc", - ":envoy_service_discovery_v3_java_grpc", - ":envoy_service_load_stats_v2_java_grpc", - ":envoy_service_load_stats_v3_java_grpc", - ":envoy_service_status_v3_java_grpc", + name = "xds_deps_depend", + exports = [ ":orca", - ":xds_protos_java", "//:auto_value_annotations", "//alts", "//api", + "//auth", "//context", "//core:internal", "//netty", @@ -44,9 +33,9 @@ java_library( "//services:metrics_internal", "//stub", "//util", - "@com_google_googleapis//google/rpc:rpc_java_proto", "@com_google_protobuf//:protobuf_java", "@com_google_protobuf//:protobuf_java_util", + "@maven//:com_google_auth_google_auth_library_oauth2_http", artifact("com.google.code.findbugs:jsr305"), artifact("com.google.code.gson:gson"), artifact("com.google.errorprone:error_prone_annotations"), @@ -58,91 +47,89 @@ java_library( artifact("io.netty:netty-handler"), artifact("io.netty:netty-transport"), ], -) - -java_proto_library( - name = "xds_protos_java", - deps = [ - "@com_github_cncf_xds//udpa/type/v1:pkg", - "@com_github_cncf_xds//xds/data/orca/v3:pkg", - "@com_github_cncf_xds//xds/service/orca/v3:pkg", - "@com_github_cncf_xds//xds/type/v3:pkg", - 
"@envoy_api//envoy/admin/v3:pkg", - "@envoy_api//envoy/api/v2:pkg", - "@envoy_api//envoy/api/v2/core:pkg", - "@envoy_api//envoy/api/v2/endpoint:pkg", - "@envoy_api//envoy/config/cluster/aggregate/v2alpha:pkg", - "@envoy_api//envoy/config/cluster/v3:pkg", - "@envoy_api//envoy/config/core/v3:pkg", - "@envoy_api//envoy/config/endpoint/v3:pkg", - "@envoy_api//envoy/config/filter/http/fault/v2:pkg", - "@envoy_api//envoy/config/filter/http/router/v2:pkg", - "@envoy_api//envoy/config/filter/network/http_connection_manager/v2:pkg", - "@envoy_api//envoy/config/listener/v3:pkg", - "@envoy_api//envoy/config/rbac/v3:pkg", - "@envoy_api//envoy/config/route/v3:pkg", - "@envoy_api//envoy/extensions/clusters/aggregate/v3:pkg", - "@envoy_api//envoy/extensions/filters/common/fault/v3:pkg", - "@envoy_api//envoy/extensions/filters/http/fault/v3:pkg", - "@envoy_api//envoy/extensions/filters/http/rbac/v3:pkg", - "@envoy_api//envoy/extensions/filters/http/router/v3:pkg", - "@envoy_api//envoy/extensions/filters/network/http_connection_manager/v3:pkg", - "@envoy_api//envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3:pkg", - "@envoy_api//envoy/extensions/load_balancing_policies/least_request/v3:pkg", - "@envoy_api//envoy/extensions/load_balancing_policies/pick_first/v3:pkg", - "@envoy_api//envoy/extensions/load_balancing_policies/ring_hash/v3:pkg", - "@envoy_api//envoy/extensions/load_balancing_policies/round_robin/v3:pkg", - "@envoy_api//envoy/extensions/load_balancing_policies/wrr_locality/v3:pkg", - "@envoy_api//envoy/extensions/transport_sockets/tls/v3:pkg", - "@envoy_api//envoy/service/discovery/v2:pkg", - "@envoy_api//envoy/service/discovery/v3:pkg", - "@envoy_api//envoy/service/load_stats/v2:pkg", - "@envoy_api//envoy/service/load_stats/v3:pkg", - "@envoy_api//envoy/service/status/v3:pkg", - "@envoy_api//envoy/type/matcher/v3:pkg", - "@envoy_api//envoy/type/v3:pkg", + runtime_deps = [ + "//compiler:java_grpc_library_deps__do_not_reference", ], ) 
-java_grpc_library( - name = "envoy_service_discovery_v2_java_grpc", - srcs = ["@envoy_api//envoy/service/discovery/v2:pkg"], - deps = [":xds_protos_java"], +java_library( + name = "xds_deps_depend_neverlink", + neverlink = 1, + exports = [":xds_deps_depend"], ) -java_grpc_library( - name = "envoy_service_discovery_v3_java_grpc", - srcs = ["@envoy_api//envoy/service/discovery/v3:pkg"], - deps = [":xds_protos_java"], +# Deps to be combined into the :xds jar itself +java_library( + name = "xds_deps_embed", + exports = [ + ":envoy_java_grpc", + ":envoy_java_proto", + ":googleapis_rpc_java_proto", + ":xds_java_proto", + ], ) -java_grpc_library( - name = "envoy_service_load_stats_v2_java_grpc", - srcs = ["@envoy_api//envoy/service/load_stats/v2:pkg"], - deps = [":xds_protos_java"], +java_binary( + name = "xds_notjarjar", + srcs = glob( + [ + "src/main/java/**/*.java", + "third_party/zero-allocation-hashing/main/java/**/*.java", + ], + exclude = ["src/main/java/io/grpc/xds/orca/**"], + ), + main_class = "unused", + resources = glob([ + "src/main/resources/**", + ]), + deps = [ + # Do not add additional dependencies here; add them to one of these two deps instead + ":xds_deps_depend_neverlink", + ":xds_deps_embed", + ], ) -java_grpc_library( - name = "envoy_service_load_stats_v3_java_grpc", - srcs = ["@envoy_api//envoy/service/load_stats/v3:pkg"], - deps = [":xds_protos_java"], -) +JAR_JAR_RULES = [ + "zap com.google.protobuf.**", # Drop codegen dep + # Keep in sync with build.gradle's shadowJar + "rule com.github.udpa.** io.grpc.xds.shaded.com.github.udpa.@1", + "rule com.github.xds.** io.grpc.xds.shaded.com.github.xds.@1", + "rule com.google.api.expr.** io.grpc.xds.shaded.com.google.api.expr.@1", + "rule com.google.security.** io.grpc.xds.shaded.com.google.security.@1", + "rule dev.cel.expr.** io.grpc.xds.shaded.dev.cel.expr.@1", + "rule envoy.annotations.** io.grpc.xds.shaded.envoy.annotations.@1", + "rule io.envoyproxy.** io.grpc.xds.shaded.io.envoyproxy.@1", + "rule 
udpa.annotations.** io.grpc.xds.shaded.udpa.annotations.@1", + "rule xds.annotations.** io.grpc.xds.shaded.xds.annotations.@1", +] -java_grpc_library( - name = "envoy_service_status_v3_java_grpc", - srcs = ["@envoy_api//envoy/service/status/v3:pkg"], - deps = [":xds_protos_java"], +jar_jar( + name = "xds_jarjar", + inline_rules = JAR_JAR_RULES, + input_jar = ":xds_notjarjar_deploy.jar", ) java_library( - name = "orca", - srcs = glob([ - "src/main/java/io/grpc/xds/orca/*.java", - ]), + name = "xds", visibility = ["//visibility:public"], + exports = [":xds_jarjar"], + runtime_deps = [":xds_deps_depend"], +) + +java_proto_library( + name = "googleapis_rpc_java_proto", deps = [ - ":orca_protos_java", - ":xds_service_orca_v3_java_grpc", + "@com_google_googleapis//google/rpc:code_proto", + "@com_google_googleapis//google/rpc:status_proto", + ], +) + +# Ordinary deps for :orca +java_library( + name = "orca_deps_depend", + exports = [ + ":xds_orca_java_grpc", + ":xds_orca_java_proto", "//api", "//context", "//core:internal", @@ -157,16 +144,222 @@ java_library( ], ) +java_library( + name = "orca_deps_depend_neverlink", + neverlink = 1, + exports = [":orca_deps_depend"], +) + +# Deps to be combined into the :orca jar itself +java_library( + name = "orca_deps_embed", + exports = [ + ":xds_orca_java_grpc", + ":xds_orca_java_proto", + ], +) + +java_binary( + name = "orca_notjarjar", + srcs = glob([ + "src/main/java/io/grpc/xds/orca/*.java", + ]), + main_class = "unused", + visibility = ["//visibility:public"], + deps = [ + # Do not add additional dependencies here; add them to one of these two deps instead + ":orca_deps_depend_neverlink", + ":orca_deps_embed", + ], +) + +jar_jar( + name = "orca_jarjar", + inline_rules = JAR_JAR_RULES, + input_jar = ":orca_notjarjar_deploy.jar", +) + +java_library( + name = "orca", + visibility = ["//visibility:public"], + exports = [":orca_jarjar"], + runtime_deps = [":orca_deps_depend"], +) + +java_proto_library( + name = "orca_java_proto", + 
deps = [":xds_proto"], +) + +java_grpc_library( + name = "orca_java_grpc", + srcs = [":xds_proto"], + deps = [":orca_java_proto"], +) + +proto_library( + name = "cel_spec_proto", + srcs = glob(["third_party/cel-spec/src/main/proto/**/*.proto"]), + strip_import_prefix = "third_party/cel-spec/src/main/proto/", + deps = [ + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:empty_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:timestamp_proto", + ], +) + +proto_library( + name = "envoy_proto", + srcs = glob(["third_party/envoy/src/main/proto/**/*.proto"]), + strip_import_prefix = "third_party/envoy/src/main/proto/", + deps = [ + ":googleapis_proto", + ":protoc_gen_validate_proto", + ":xds_proto", + "@com_google_googleapis//google/api:annotations_proto", + "@com_google_googleapis//google/rpc:status_proto", + "@com_google_protobuf//:any_proto", + "@com_google_protobuf//:descriptor_proto", + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:empty_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:timestamp_proto", + "@com_google_protobuf//:wrappers_proto", + ], +) + java_proto_library( - name = "orca_protos_java", + name = "envoy_java_proto", + deps = [":envoy_proto"], +) + +INTERNAL_java_grpc_library_for_xds( + name = "envoy_java_grpc", + srcs = [":envoy_proto"], + deps = [":envoy_java_proto"], +) + +proto_library( + name = "googleapis_proto", + srcs = glob(["third_party/googleapis/src/main/proto/**/*.proto"]), + strip_import_prefix = "third_party/googleapis/src/main/proto/", deps = [ - "@com_github_cncf_xds//xds/data/orca/v3:pkg", - "@com_github_cncf_xds//xds/service/orca/v3:pkg", + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:empty_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:timestamp_proto", ], ) +proto_library( + name = "protoc_gen_validate_proto", + srcs = glob(["third_party/protoc-gen-validate/src/main/proto/**/*.proto"]), + 
strip_import_prefix = "third_party/protoc-gen-validate/src/main/proto/", + deps = [ + "@com_google_protobuf//:descriptor_proto", + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:timestamp_proto", + ], +) + +proto_library( + name = "xds_proto", + srcs = glob( + ["third_party/xds/src/main/proto/**/*.proto"], + exclude = [ + "third_party/xds/src/main/proto/xds/data/orca/v3/*.proto", + "third_party/xds/src/main/proto/xds/service/orca/v3/*.proto", + ], + ), + strip_import_prefix = "third_party/xds/src/main/proto/", + deps = [ + ":cel_spec_proto", + ":googleapis_proto", + ":protoc_gen_validate_proto", + "@com_google_protobuf//:any_proto", + "@com_google_protobuf//:descriptor_proto", + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:wrappers_proto", + ], +) + +java_proto_library( + name = "xds_java_proto", + deps = [":xds_proto"], +) + +proto_library( + name = "xds_orca_proto", + srcs = glob([ + "third_party/xds/src/main/proto/xds/data/orca/v3/*.proto", + "third_party/xds/src/main/proto/xds/service/orca/v3/*.proto", + ]), + strip_import_prefix = "third_party/xds/src/main/proto/", + deps = [ + ":protoc_gen_validate_proto", + "@com_google_protobuf//:duration_proto", + ], +) + +java_proto_library( + name = "xds_orca_java_proto", + deps = [":xds_orca_proto"], +) + java_grpc_library( - name = "xds_service_orca_v3_java_grpc", - srcs = ["@com_github_cncf_xds//xds/service/orca/v3:pkg"], - deps = [":orca_protos_java"], + name = "xds_orca_java_grpc", + srcs = [":xds_orca_proto"], + deps = [":xds_orca_java_proto"], +) + +java_rpc_toolchain( + name = "java_grpc_library_toolchain", + plugin = "//compiler:grpc_java_plugin", + runtime = [":java_grpc_library_deps"], +) + +java_library( + name = "java_grpc_library_deps", + neverlink = 1, + exports = ["//compiler:java_grpc_library_deps__do_not_reference"], +) + +java_library( + name = "testlib", + testonly = 1, + srcs = [ + 
"src/test/java/io/grpc/xds/ControlPlaneRule.java", + "src/test/java/io/grpc/xds/DataPlaneRule.java", + "src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java", + "src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java", + "src/test/java/io/grpc/xds/XdsTestControlPlaneService.java", + "src/test/java/io/grpc/xds/XdsTestLoadReportingService.java", + ], + deps = [ + ":envoy_java_grpc", + ":envoy_java_proto", + ":xds", + ":xds_java_proto", + "//api", + "//api:test_fixtures", + "//core:internal", + "//stub", + "//testing-proto:simpleservice_java_grpc", + "//testing-proto:simpleservice_java_proto", + "//util", + "@com_google_protobuf//java/core", + "@maven//:com_google_code_findbugs_jsr305", + "@maven//:com_google_guava_guava", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + ], +) + +java_test( + name = "FakeControlPlaneXdsIntegrationTest", + size = "small", + test_class = "io.grpc.xds.FakeControlPlaneXdsIntegrationTest", + runtime_deps = [":testlib"], ) diff --git a/xds/build.gradle b/xds/build.gradle index a738145a2a0..8394fe12f6b 100644 --- a/xds/build.gradle +++ b/xds/build.gradle @@ -4,8 +4,8 @@ plugins { id "java" id "maven-publish" - id "com.github.johnrengelman.shadow" id "com.google.protobuf" + id "com.gradleup.shadow" id "ru.vyarus.animalsniffer" } @@ -17,12 +17,11 @@ sourceSets { srcDir "${projectDir}/third_party/zero-allocation-hashing/main/java" } proto { + srcDir 'third_party/cel-spec/src/main/proto' srcDir 'third_party/envoy/src/main/proto' + srcDir 'third_party/googleapis/src/main/proto' srcDir 'third_party/protoc-gen-validate/src/main/proto' srcDir 'third_party/xds/src/main/proto' - srcDir 'third_party/cel-spec/src/main/proto' - srcDir 'third_party/googleapis/src/main/proto' - srcDir 'third_party/istio/src/main/proto' } } main { @@ -42,10 +41,10 @@ configurations { } dependencies { - thirdpartyCompileOnly libraries.javax.annotation thirdpartyImplementation project(':grpc-protobuf'), project(':grpc-stub') compileOnly 
sourceSets.thirdparty.output + testCompileOnly sourceSets.thirdparty.output implementation project(':grpc-stub'), project(':grpc-core'), project(':grpc-util'), @@ -59,6 +58,7 @@ dependencies { libraries.protobuf.java.util def nettyDependency = implementation project(':grpc-netty') + testImplementation project(':grpc-api') testImplementation project(':grpc-rls') testImplementation project(':grpc-inprocess') testImplementation testFixtures(project(':grpc-core')), @@ -81,7 +81,11 @@ dependencies { shadow configurations.implementation.getDependencies().minus([nettyDependency]) shadow project(path: ':grpc-netty-shaded', configuration: 'shadow') - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } testRuntimeOnly libraries.netty.tcnative, libraries.netty.tcnative.classes testRuntimeOnly (libraries.netty.tcnative) { @@ -127,8 +131,6 @@ tasks.named("checkstyleThirdparty").configure { tasks.named("compileJava").configure { it.options.compilerArgs += [ - // TODO: remove - "-Xlint:-deprecation", // only has AutoValue annotation processor "-Xlint:-processing", ] @@ -182,6 +184,7 @@ tasks.named("shadowJar").configure { include(project(':grpc-xds')) } // Relocated packages commonly need exclusions in jacocoTestReport and javadoc + // Keep in sync with BUILD.bazel's JAR_JAR_RULES relocate 'com.github.udpa', "${prefixName}.shaded.com.github.udpa" relocate 'com.github.xds', "${prefixName}.shaded.com.github.xds" relocate 'com.google.api.expr', "${prefixName}.shaded.com.google.api.expr" diff --git a/xds/src/generated/thirdparty/grpc/com/github/xds/service/orca/v3/OpenRcaServiceGrpc.java b/xds/src/generated/thirdparty/grpc/com/github/xds/service/orca/v3/OpenRcaServiceGrpc.java index de2c7424fca..e0e28ad4072 100644 --- a/xds/src/generated/thirdparty/grpc/com/github/xds/service/orca/v3/OpenRcaServiceGrpc.java +++ b/xds/src/generated/thirdparty/grpc/com/github/xds/service/orca/v3/OpenRcaServiceGrpc.java @@ -14,9 
+14,6 @@ * a new call to change backend reporting frequency. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: xds/service/orca/v3/orca.proto") @io.grpc.stub.annotations.GrpcGenerated public final class OpenRcaServiceGrpc { @@ -70,6 +67,21 @@ public OpenRcaServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions c return OpenRcaServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static OpenRcaServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public OpenRcaServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new OpenRcaServiceBlockingV2Stub(channel, callOptions); + } + }; + return OpenRcaServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -191,6 +203,42 @@ public void streamCoreMetrics(com.github.xds.service.orca.v3.OrcaLoadReportReque * a new call to change backend reporting frequency. 
* */ + public static final class OpenRcaServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private OpenRcaServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected OpenRcaServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new OpenRcaServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamCoreMetrics(com.github.xds.service.orca.v3.OrcaLoadReportRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamCoreMetricsMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service OpenRcaService. + *
+   * Out-of-band (OOB) load reporting service for the additional load reporting
+   * agent that does not sit in the request path. Reports are periodically sampled
+   * with sufficient frequency to provide temporal association with requests.
+   * OOB reporting compensates the limitation of in-band reporting in revealing
+   * costs for backends that do not provide a steady stream of telemetry such as
+   * long running stream operations and zero QPS services. This is a server
+   * streaming service, client needs to terminate current RPC and initiate
+   * a new call to change backend reporting frequency.
+   * 
+ */ public static final class OpenRcaServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private OpenRcaServiceBlockingStub( diff --git a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/auth/v3/AuthorizationGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/auth/v3/AuthorizationGrpc.java new file mode 100644 index 00000000000..df9b7a3514b --- /dev/null +++ b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/auth/v3/AuthorizationGrpc.java @@ -0,0 +1,377 @@ +package io.envoyproxy.envoy.service.auth.v3; + +import static io.grpc.MethodDescriptor.generateFullMethodName; + +/** + *
+ * A generic interface for performing authorization check on incoming
+ * requests to a networked service.
+ * 
+ */ +@io.grpc.stub.annotations.GrpcGenerated +public final class AuthorizationGrpc { + + private AuthorizationGrpc() {} + + public static final java.lang.String SERVICE_NAME = "envoy.service.auth.v3.Authorization"; + + // Static method descriptors that strictly reflect the proto. + private static volatile io.grpc.MethodDescriptor getCheckMethod; + + @io.grpc.stub.annotations.RpcMethod( + fullMethodName = SERVICE_NAME + '/' + "Check", + requestType = io.envoyproxy.envoy.service.auth.v3.CheckRequest.class, + responseType = io.envoyproxy.envoy.service.auth.v3.CheckResponse.class, + methodType = io.grpc.MethodDescriptor.MethodType.UNARY) + public static io.grpc.MethodDescriptor getCheckMethod() { + io.grpc.MethodDescriptor getCheckMethod; + if ((getCheckMethod = AuthorizationGrpc.getCheckMethod) == null) { + synchronized (AuthorizationGrpc.class) { + if ((getCheckMethod = AuthorizationGrpc.getCheckMethod) == null) { + AuthorizationGrpc.getCheckMethod = getCheckMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.UNARY) + .setFullMethodName(generateFullMethodName(SERVICE_NAME, "Check")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.envoyproxy.envoy.service.auth.v3.CheckRequest.getDefaultInstance())) + .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.envoyproxy.envoy.service.auth.v3.CheckResponse.getDefaultInstance())) + .setSchemaDescriptor(new AuthorizationMethodDescriptorSupplier("Check")) + .build(); + } + } + } + return getCheckMethod; + } + + /** + * Creates a new async stub that supports all call types for the service + */ + public static AuthorizationStub newStub(io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AuthorizationStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationStub(channel, 
callOptions); + } + }; + return AuthorizationStub.newStub(factory, channel); + } + + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static AuthorizationBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AuthorizationBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationBlockingV2Stub(channel, callOptions); + } + }; + return AuthorizationBlockingV2Stub.newStub(factory, channel); + } + + /** + * Creates a new blocking-style stub that supports unary and streaming output calls on the service + */ + public static AuthorizationBlockingStub newBlockingStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AuthorizationBlockingStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationBlockingStub(channel, callOptions); + } + }; + return AuthorizationBlockingStub.newStub(factory, channel); + } + + /** + * Creates a new ListenableFuture-style stub that supports unary calls on the service + */ + public static AuthorizationFutureStub newFutureStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AuthorizationFutureStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationFutureStub(channel, callOptions); + } + }; + return AuthorizationFutureStub.newStub(factory, channel); + } + + /** + *
+   * A generic interface for performing authorization check on incoming
+   * requests to a networked service.
+   * 
+ */ + public interface AsyncService { + + /** + *
+     * Performs authorization check based on the attributes associated with the
+     * incoming request, and returns status `OK` or not `OK`.
+     * 
+ */ + default void check(io.envoyproxy.envoy.service.auth.v3.CheckRequest request, + io.grpc.stub.StreamObserver responseObserver) { + io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getCheckMethod(), responseObserver); + } + } + + /** + * Base class for the server implementation of the service Authorization. + *
+   * A generic interface for performing authorization check on incoming
+   * requests to a networked service.
+   * 
+ */ + public static abstract class AuthorizationImplBase + implements io.grpc.BindableService, AsyncService { + + @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() { + return AuthorizationGrpc.bindService(this); + } + } + + /** + * A stub to allow clients to do asynchronous rpc calls to service Authorization. + *
+   * A generic interface for performing authorization check on incoming
+   * requests to a networked service.
+   * 
+ */ + public static final class AuthorizationStub + extends io.grpc.stub.AbstractAsyncStub { + private AuthorizationStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AuthorizationStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationStub(channel, callOptions); + } + + /** + *
+     * Performs authorization check based on the attributes associated with the
+     * incoming request, and returns status `OK` or not `OK`.
+     * 
+ */ + public void check(io.envoyproxy.envoy.service.auth.v3.CheckRequest request, + io.grpc.stub.StreamObserver responseObserver) { + io.grpc.stub.ClientCalls.asyncUnaryCall( + getChannel().newCall(getCheckMethod(), getCallOptions()), request, responseObserver); + } + } + + /** + * A stub to allow clients to do synchronous rpc calls to service Authorization. + *
+   * A generic interface for performing authorization check on incoming
+   * requests to a networked service.
+   * 
+ */ + public static final class AuthorizationBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private AuthorizationBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AuthorizationBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Performs authorization check based on the attributes associated with the
+     * incoming request, and returns status `OK` or not `OK`.
+     * 
+ */ + public io.envoyproxy.envoy.service.auth.v3.CheckResponse check(io.envoyproxy.envoy.service.auth.v3.CheckRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getCheckMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service Authorization. + *
+   * A generic interface for performing authorization check on incoming
+   * requests to a networked service.
+   * 
+ */ + public static final class AuthorizationBlockingStub + extends io.grpc.stub.AbstractBlockingStub { + private AuthorizationBlockingStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AuthorizationBlockingStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationBlockingStub(channel, callOptions); + } + + /** + *
+     * Performs authorization check based on the attributes associated with the
+     * incoming request, and returns status `OK` or not `OK`.
+     * 
+ */ + public io.envoyproxy.envoy.service.auth.v3.CheckResponse check(io.envoyproxy.envoy.service.auth.v3.CheckRequest request) { + return io.grpc.stub.ClientCalls.blockingUnaryCall( + getChannel(), getCheckMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do ListenableFuture-style rpc calls to service Authorization. + *
+   * A generic interface for performing authorization check on incoming
+   * requests to a networked service.
+   * 
+ */ + public static final class AuthorizationFutureStub + extends io.grpc.stub.AbstractFutureStub { + private AuthorizationFutureStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AuthorizationFutureStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationFutureStub(channel, callOptions); + } + + /** + *
+     * Performs authorization check based on the attributes associated with the
+     * incoming request, and returns status `OK` or not `OK`.
+     * 
+ */ + public com.google.common.util.concurrent.ListenableFuture check( + io.envoyproxy.envoy.service.auth.v3.CheckRequest request) { + return io.grpc.stub.ClientCalls.futureUnaryCall( + getChannel().newCall(getCheckMethod(), getCallOptions()), request); + } + } + + private static final int METHODID_CHECK = 0; + + private static final class MethodHandlers implements + io.grpc.stub.ServerCalls.UnaryMethod, + io.grpc.stub.ServerCalls.ServerStreamingMethod, + io.grpc.stub.ServerCalls.ClientStreamingMethod, + io.grpc.stub.ServerCalls.BidiStreamingMethod { + private final AsyncService serviceImpl; + private final int methodId; + + MethodHandlers(AsyncService serviceImpl, int methodId) { + this.serviceImpl = serviceImpl; + this.methodId = methodId; + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + case METHODID_CHECK: + serviceImpl.check((io.envoyproxy.envoy.service.auth.v3.CheckRequest) request, + (io.grpc.stub.StreamObserver) responseObserver); + break; + default: + throw new AssertionError(); + } + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public io.grpc.stub.StreamObserver invoke( + io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + default: + throw new AssertionError(); + } + } + } + + public static final io.grpc.ServerServiceDefinition bindService(AsyncService service) { + return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) + .addMethod( + getCheckMethod(), + io.grpc.stub.ServerCalls.asyncUnaryCall( + new MethodHandlers< + io.envoyproxy.envoy.service.auth.v3.CheckRequest, + io.envoyproxy.envoy.service.auth.v3.CheckResponse>( + service, METHODID_CHECK))) + .build(); + } + + private static abstract class AuthorizationBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoFileDescriptorSupplier, io.grpc.protobuf.ProtoServiceDescriptorSupplier { + 
AuthorizationBaseDescriptorSupplier() {} + + @java.lang.Override + public com.google.protobuf.Descriptors.FileDescriptor getFileDescriptor() { + return io.envoyproxy.envoy.service.auth.v3.ExternalAuthProto.getDescriptor(); + } + + @java.lang.Override + public com.google.protobuf.Descriptors.ServiceDescriptor getServiceDescriptor() { + return getFileDescriptor().findServiceByName("Authorization"); + } + } + + private static final class AuthorizationFileDescriptorSupplier + extends AuthorizationBaseDescriptorSupplier { + AuthorizationFileDescriptorSupplier() {} + } + + private static final class AuthorizationMethodDescriptorSupplier + extends AuthorizationBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoMethodDescriptorSupplier { + private final java.lang.String methodName; + + AuthorizationMethodDescriptorSupplier(java.lang.String methodName) { + this.methodName = methodName; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.MethodDescriptor getMethodDescriptor() { + return getServiceDescriptor().findMethodByName(methodName); + } + } + + private static volatile io.grpc.ServiceDescriptor serviceDescriptor; + + public static io.grpc.ServiceDescriptor getServiceDescriptor() { + io.grpc.ServiceDescriptor result = serviceDescriptor; + if (result == null) { + synchronized (AuthorizationGrpc.class) { + result = serviceDescriptor; + if (result == null) { + serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME) + .setSchemaDescriptor(new AuthorizationFileDescriptorSupplier()) + .addMethod(getCheckMethod()) + .build(); + } + } + } + return result; + } +} diff --git a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java index e039c2193e8..94b2fd86b96 100644 --- 
a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java +++ b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java @@ -12,9 +12,6 @@ * the multiplexed singleton APIs at the Envoy instance and management server. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: envoy/service/discovery/v3/ads.proto") @io.grpc.stub.annotations.GrpcGenerated public final class AggregatedDiscoveryServiceGrpc { @@ -99,6 +96,21 @@ public AggregatedDiscoveryServiceStub newStub(io.grpc.Channel channel, io.grpc.C return AggregatedDiscoveryServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static AggregatedDiscoveryServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AggregatedDiscoveryServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AggregatedDiscoveryServiceBlockingV2Stub(channel, callOptions); + } + }; + return AggregatedDiscoveryServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -233,6 +245,52 @@ public io.grpc.stub.StreamObserver */ + public static final class AggregatedDiscoveryServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private AggregatedDiscoveryServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AggregatedDiscoveryServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AggregatedDiscoveryServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * This is a gRPC-only API.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamAggregatedResources() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getStreamAggregatedResourcesMethod(), getCallOptions()); + } + + /** + */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + deltaAggregatedResources() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getDeltaAggregatedResourcesMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service AggregatedDiscoveryService. + *
+   * See https://github.com/envoyproxy/envoy-api#apis for a description of the role of
+   * ADS and how it is intended to be used by a management server. ADS requests
+   * have the same structure as their singleton xDS counterparts, but can
+   * multiplex many resource types on a single stream. The type_url in the
+   * DiscoveryRequest/DiscoveryResponse provides sufficient information to recover
+   * the multiplexed singleton APIs at the Envoy instance and management server.
+   * 
+ */ public static final class AggregatedDiscoveryServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private AggregatedDiscoveryServiceBlockingStub( diff --git a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/load_stats/v3/LoadReportingServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/load_stats/v3/LoadReportingServiceGrpc.java index 2adbf02e98a..4f12405be87 100644 --- a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/load_stats/v3/LoadReportingServiceGrpc.java +++ b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/load_stats/v3/LoadReportingServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: envoy/service/load_stats/v3/lrs.proto") @io.grpc.stub.annotations.GrpcGenerated public final class LoadReportingServiceGrpc { @@ -60,6 +57,21 @@ public LoadReportingServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOpt return LoadReportingServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static LoadReportingServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public LoadReportingServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadReportingServiceBlockingV2Stub(channel, callOptions); + } + }; + return LoadReportingServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -199,6 +211,61 @@ public io.grpc.stub.StreamObserver { + private LoadReportingServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected 
LoadReportingServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadReportingServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Advanced API to allow for multi-dimensional load balancing by remote
+     * server. For receiving LB assignments, the steps are:
+     * 1, The management server is configured with per cluster/zone/load metric
+     *    capacity configuration. The capacity configuration definition is
+     *    outside of the scope of this document.
+     * 2. Envoy issues a standard {Stream,Fetch}Endpoints request for the clusters
+     *    to balance.
+     * Independently, Envoy will initiate a StreamLoadStats bidi stream with a
+     * management server:
+     * 1. Once a connection establishes, the management server publishes a
+     *    LoadStatsResponse for all clusters it is interested in learning load
+     *    stats about.
+     * 2. For each cluster, Envoy load balances incoming traffic to upstream hosts
+     *    based on per-zone weights and/or per-instance weights (if specified)
+     *    based on intra-zone LbPolicy. This information comes from the above
+     *    {Stream,Fetch}Endpoints.
+     * 3. When upstream hosts reply, they optionally add header <define header
+     *    name> with ASCII representation of EndpointLoadMetricStats.
+     * 4. Envoy aggregates load reports over the period of time given to it in
+     *    LoadStatsResponse.load_reporting_interval. This includes aggregation
+     *    stats Envoy maintains by itself (total_requests, rpc_errors etc.) as
+     *    well as load metrics from upstream hosts.
+     * 5. When the timer of load_reporting_interval expires, Envoy sends new
+     *    LoadStatsRequest filled with load reports for each cluster.
+     * 6. The management server uses the load reports from all reported Envoys
+     *    from around the world, computes global assignment and prepares traffic
+     *    assignment destined for each zone Envoys are located in. Goto 2.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamLoadStats() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getStreamLoadStatsMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service LoadReportingService. + */ public static final class LoadReportingServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private LoadReportingServiceBlockingStub( diff --git a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/rate_limit_quota/v3/RateLimitQuotaServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/rate_limit_quota/v3/RateLimitQuotaServiceGrpc.java index 2cbb7536d4c..3f17bb54566 100644 --- a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/rate_limit_quota/v3/RateLimitQuotaServiceGrpc.java +++ b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/rate_limit_quota/v3/RateLimitQuotaServiceGrpc.java @@ -7,9 +7,6 @@ * Defines the Rate Limit Quota Service (RLQS). 
* */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: envoy/service/rate_limit_quota/v3/rlqs.proto") @io.grpc.stub.annotations.GrpcGenerated public final class RateLimitQuotaServiceGrpc { @@ -63,6 +60,21 @@ public RateLimitQuotaServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOp return RateLimitQuotaServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static RateLimitQuotaServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public RateLimitQuotaServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RateLimitQuotaServiceBlockingV2Stub(channel, callOptions); + } + }; + return RateLimitQuotaServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -164,6 +176,39 @@ public io.grpc.stub.StreamObserver */ + public static final class RateLimitQuotaServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private RateLimitQuotaServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected RateLimitQuotaServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RateLimitQuotaServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Main communication channel: the data plane sends usage reports to the RLQS server,
+     * and the server asynchronously responding with the assignments.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamRateLimitQuotas() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getStreamRateLimitQuotasMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service RateLimitQuotaService. + *
+   * Defines the Rate Limit Quota Service (RLQS).
+   * 
+ */ public static final class RateLimitQuotaServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private RateLimitQuotaServiceBlockingStub( diff --git a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/status/v3/ClientStatusDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/status/v3/ClientStatusDiscoveryServiceGrpc.java index 3f8874248d0..cb166503566 100644 --- a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/status/v3/ClientStatusDiscoveryServiceGrpc.java +++ b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/status/v3/ClientStatusDiscoveryServiceGrpc.java @@ -9,9 +9,6 @@ * also be used to get the current xDS states directly from the client. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: envoy/service/status/v3/csds.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ClientStatusDiscoveryServiceGrpc { @@ -96,6 +93,21 @@ public ClientStatusDiscoveryServiceStub newStub(io.grpc.Channel channel, io.grpc return ClientStatusDiscoveryServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ClientStatusDiscoveryServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ClientStatusDiscoveryServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ClientStatusDiscoveryServiceBlockingV2Stub(channel, callOptions); + } + }; + return ClientStatusDiscoveryServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -212,6 +224,44 @@ public void fetchClientStatus(io.envoyproxy.envoy.service.status.v3.ClientStatus * also be used to get the current xDS 
states directly from the client. * */ + public static final class ClientStatusDiscoveryServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ClientStatusDiscoveryServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ClientStatusDiscoveryServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ClientStatusDiscoveryServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamClientStatus() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getStreamClientStatusMethod(), getCallOptions()); + } + + /** + */ + public io.envoyproxy.envoy.service.status.v3.ClientStatusResponse fetchClientStatus(io.envoyproxy.envoy.service.status.v3.ClientStatusRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getFetchClientStatusMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ClientStatusDiscoveryService. + *
+   * CSDS is Client Status Discovery Service. It can be used to get the status of
+   * an xDS-compliant client from the management server's point of view. It can
+   * also be used to get the current xDS states directly from the client.
+   * 
+ */ public static final class ClientStatusDiscoveryServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ClientStatusDiscoveryServiceBlockingStub( diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index 773fdf20563..5a59b47c529 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -18,39 +18,57 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.xds.XdsLbPolicies.CLUSTER_RESOLVER_POLICY_NAME; - -import com.google.common.annotations.VisibleForTesting; +import static io.grpc.xds.XdsLbPolicies.CDS_POLICY_NAME; +import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; + +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.UnsignedInts; +import com.google.errorprone.annotations.CheckReturnValue; +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; +import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver; import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.internal.ObjectPool; +import io.grpc.StatusOr; +import io.grpc.internal.GrpcUtil; import io.grpc.util.GracefulSwitchLoadBalancer; +import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig.DiscoveryMechanism; +import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; +import io.grpc.xds.Endpoints.DropOverload; +import io.grpc.xds.Endpoints.LbEndpoint; +import 
io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.EnvoyServerProtoData.FailurePercentageEjection; +import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; +import io.grpc.xds.EnvoyServerProtoData.SuccessRateEjection; +import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig; +import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig; import io.grpc.xds.XdsClusterResource.CdsUpdate; import io.grpc.xds.XdsClusterResource.CdsUpdate.ClusterType; -import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsClient.ResourceWatcher; +import io.grpc.xds.XdsConfig.Subscription; +import io.grpc.xds.XdsConfig.XdsClusterConfig; +import io.grpc.xds.XdsConfig.XdsClusterConfig.AggregateConfig; +import io.grpc.xds.XdsConfig.XdsClusterConfig.EndpointConfig; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.client.Locality; import io.grpc.xds.client.XdsLogger; import io.grpc.xds.client.XdsLogger.XdsLogLevel; -import java.util.ArrayDeque; +import io.grpc.xds.internal.XdsInternalAttributes; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Queue; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import javax.annotation.Nullable; +import java.util.TreeMap; /** * Load balancer for cds_experimental LB policy. One instance per top-level cluster. @@ -58,50 +76,128 @@ * by a group of sub-clusters in a tree hierarchy. 
*/ final class CdsLoadBalancer2 extends LoadBalancer { + static boolean pickFirstWeightedShuffling = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true); + private final XdsLogger logger; private final Helper helper; - private final SynchronizationContext syncContext; private final LoadBalancerRegistry lbRegistry; + private final ClusterState clusterState = new ClusterState(); + private GracefulSwitchLoadBalancer delegate; // Following fields are effectively final. - private ObjectPool xdsClientPool; - private XdsClient xdsClient; - private CdsLbState cdsLbState; - private ResolvedAddresses resolvedAddresses; - - CdsLoadBalancer2(Helper helper) { - this(helper, LoadBalancerRegistry.getDefaultRegistry()); - } + private String clusterName; + private Subscription clusterSubscription; - @VisibleForTesting CdsLoadBalancer2(Helper helper, LoadBalancerRegistry lbRegistry) { this.helper = checkNotNull(helper, "helper"); - this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); this.lbRegistry = checkNotNull(lbRegistry, "lbRegistry"); + this.delegate = new GracefulSwitchLoadBalancer(helper); logger = XdsLogger.withLogId(InternalLogId.allocate("cds-lb", helper.getAuthority())); logger.log(XdsLogLevel.INFO, "Created"); } @Override public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - if (this.resolvedAddresses != null) { + logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); + if (this.clusterName == null) { + CdsConfig config = (CdsConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + logger.log(XdsLogLevel.INFO, "Config: {0}", config); + if (config.isDynamic) { + clusterSubscription = resolvedAddresses.getAttributes() + .get(XdsAttributes.XDS_CLUSTER_SUBSCRIPT_REGISTRY) + .subscribeToCluster(config.name); + } + this.clusterName = config.name; + } + XdsConfig xdsConfig = resolvedAddresses.getAttributes().get(XdsAttributes.XDS_CONFIG); + StatusOr clusterConfigOr = 
xdsConfig.getClusters().get(clusterName); + if (clusterConfigOr == null) { + if (clusterSubscription == null) { + // Should be impossible, because XdsDependencyManager wouldn't have generated this + return fail(Status.INTERNAL.withDescription( + errorPrefix() + "Unable to find non-dynamic cluster")); + } + // The dynamic cluster must not have loaded yet return Status.OK; } - logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); - this.resolvedAddresses = resolvedAddresses; - xdsClientPool = resolvedAddresses.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL); - xdsClient = xdsClientPool.getObject(); - CdsConfig config = (CdsConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - logger.log(XdsLogLevel.INFO, "Config: {0}", config); - cdsLbState = new CdsLbState(config.name); - cdsLbState.start(); - return Status.OK; + if (!clusterConfigOr.hasValue()) { + return fail(clusterConfigOr.getStatus()); + } + XdsClusterConfig clusterConfig = clusterConfigOr.getValue(); + + NameResolver.ConfigOrError configOrError; + if (clusterConfig.getChildren() instanceof EndpointConfig) { + // The LB policy config is provided in service_config.proto/JSON format. 
+ configOrError = + GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( + Arrays.asList(clusterConfig.getClusterResource().lbPolicyConfig()), + lbRegistry); + if (configOrError.getError() != null) { + // Should be impossible, because XdsClusterResource validated this + return fail(Status.INTERNAL.withDescription( + errorPrefix() + "Unable to parse the LB config: " + configOrError.getError())); + } + + StatusOr edsUpdate = getEdsUpdate(xdsConfig, clusterName); + StatusOr statusOrResult = clusterState.edsUpdateToResult( + clusterName, + clusterConfig.getClusterResource(), + configOrError.getConfig(), + edsUpdate); + if (!statusOrResult.hasValue()) { + Status status = Status.UNAVAILABLE + .withDescription(statusOrResult.getStatus().getDescription()) + .withCause(statusOrResult.getStatus().getCause()); + delegate.handleNameResolutionError(status); + return status; + } + ClusterResolutionResult result = statusOrResult.getValue(); + List addresses = result.addresses; + if (addresses.isEmpty()) { + Status status = Status.UNAVAILABLE + .withDescription("No usable endpoint from cluster: " + clusterName); + delegate.handleNameResolutionError(status); + return status; + } + Object gracefulConfig = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + lbRegistry.getProvider(PRIORITY_POLICY_NAME), + new PriorityLbConfig( + Collections.unmodifiableMap(result.priorityChildConfigs), + Collections.unmodifiableList(result.priorities))); + return delegate.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(gracefulConfig) + .setAddresses(Collections.unmodifiableList(addresses)) + .build()); + } else if (clusterConfig.getChildren() instanceof AggregateConfig) { + Map priorityChildConfigs = new HashMap<>(); + List leafClusters = ((AggregateConfig) clusterConfig.getChildren()).getLeafNames(); + for (String childCluster: leafClusters) { + priorityChildConfigs.put(childCluster, + new PriorityChildConfig( + 
GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + lbRegistry.getProvider(CDS_POLICY_NAME), + new CdsConfig(childCluster)), + false)); + } + Object gracefulConfig = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + lbRegistry.getProvider(PRIORITY_POLICY_NAME), + new PriorityLoadBalancerProvider.PriorityLbConfig( + Collections.unmodifiableMap(priorityChildConfigs), leafClusters)); + return delegate.acceptResolvedAddresses( + resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(gracefulConfig).build()); + } else { + return fail(Status.INTERNAL.withDescription( + errorPrefix() + "Unexpected cluster children type: " + + clusterConfig.getChildren().getClass())); + } } @Override public void handleNameResolutionError(Status error) { logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error); - if (cdsLbState != null && cdsLbState.childLb != null) { - cdsLbState.childLb.handleNameResolutionError(error); + if (delegate != null) { + delegate.handleNameResolutionError(error); } else { helper.updateBalancingState( TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); @@ -111,308 +207,409 @@ public void handleNameResolutionError(Status error) { @Override public void shutdown() { logger.log(XdsLogLevel.INFO, "Shutdown"); - if (cdsLbState != null) { - cdsLbState.shutdown(); - } - if (xdsClientPool != null) { - xdsClientPool.returnObject(xdsClient); + delegate.shutdown(); + delegate = new GracefulSwitchLoadBalancer(helper); + if (clusterSubscription != null) { + clusterSubscription.close(); + clusterSubscription = null; } } + @CheckReturnValue // don't forget to return up the stack after the fail call + private Status fail(Status error) { + delegate.shutdown(); + helper.updateBalancingState( + TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); + return Status.OK; // XdsNameResolver isn't a polling NR, so this value doesn't matter + } + + private String errorPrefix() { + return "CdsLb for " + 
clusterName + ": "; + } + /** - * The state of a CDS working session of {@link CdsLoadBalancer2}. Created and started when - * receiving the CDS LB policy config with the top-level cluster name. + * The number of bits assigned to the fractional part of fixed-point values. We normalize weights + * to a fixed-point number between 0 and 1, representing that item's proportion of traffic (1 == + * 100% of traffic). We reserve at least one bit for the whole number so that we don't need to + * special case a single item, and so that we can round up very low values without risking uint32 + * overflow of the sum of weights. */ - private final class CdsLbState { + private static final int FIXED_POINT_FRACTIONAL_BITS = 31; - private final ClusterState root; - private final Map clusterStates = new ConcurrentHashMap<>(); - private LoadBalancer childLb; + /** Divide two uint32s and produce a fixed-point uint32 result. */ + private static long fractionToFixedPoint(long numerator, long denominator) { + long one = 1L << FIXED_POINT_FRACTIONAL_BITS; + return numerator * one / denominator; + } - private CdsLbState(String rootCluster) { - root = new ClusterState(rootCluster); - } + /** Multiply two uint32 fixed-point numbers, returning a uint32 fixed-point. 
*/ + private static long fixedPointMultiply(long a, long b) { + return (a * b) >> FIXED_POINT_FRACTIONAL_BITS; + } - private void start() { - root.start(); + private static StatusOr getEdsUpdate(XdsConfig xdsConfig, String cluster) { + StatusOr clusterConfig = xdsConfig.getClusters().get(cluster); + if (clusterConfig == null) { + return StatusOr.fromStatus(Status.INTERNAL + .withDescription("BUG: cluster resolver could not find cluster in xdsConfig")); } - - private void shutdown() { - root.shutdown(); - if (childLb != null) { - childLb.shutdown(); - } + if (!clusterConfig.hasValue()) { + return StatusOr.fromStatus(clusterConfig.getStatus()); + } + if (!(clusterConfig.getValue().getChildren() instanceof XdsClusterConfig.EndpointConfig)) { + return StatusOr.fromStatus(Status.INTERNAL + .withDescription("BUG: cluster resolver cluster with children of unknown type")); } + XdsClusterConfig.EndpointConfig endpointConfig = + (XdsClusterConfig.EndpointConfig) clusterConfig.getValue().getChildren(); + return endpointConfig.getEndpoint(); + } + + /** + * Generates a string that represents the priority in the LB policy config. The string is unique + * across priorities in all clusters and priorityName(c, p1) < priorityName(c, p2) iff p1 < p2. + * The ordering is undefined for priorities in different clusters. + */ + private static String priorityName(String cluster, int priority) { + return cluster + "[child" + priority + "]"; + } - private void handleClusterDiscovered() { - List instances = new ArrayList<>(); - - // Used for loop detection to break the infinite recursion that loops would cause - Map> parentClusters = new HashMap<>(); - Status loopStatus = null; - - // Level-order traversal. - // Collect configurations for all non-aggregate (leaf) clusters. 
- Queue queue = new ArrayDeque<>(); - queue.add(root); - while (!queue.isEmpty()) { - int size = queue.size(); - for (int i = 0; i < size; i++) { - ClusterState clusterState = queue.remove(); - if (!clusterState.discovered) { - return; // do not proceed until all clusters discovered + /** + * Generates a string that represents the locality in the LB policy config. The string is unique + * across all localities in all clusters. + */ + private static String localityName(Locality locality) { + return "{region=\"" + locality.region() + + "\", zone=\"" + locality.zone() + + "\", sub_zone=\"" + locality.subZone() + + "\"}"; + } + + private final class ClusterState { + private Map localityPriorityNames = Collections.emptyMap(); + int priorityNameGenId = 1; + + StatusOr edsUpdateToResult( + String clusterName, + CdsUpdate discovery, + Object lbConfig, + StatusOr updateOr) { + if (!updateOr.hasValue()) { + return StatusOr.fromStatus(updateOr.getStatus()); + } + EdsUpdate update = updateOr.getValue(); + logger.log(XdsLogLevel.DEBUG, "Received endpoint update {0}", update); + if (logger.isLoggable(XdsLogLevel.INFO)) { + logger.log(XdsLogLevel.INFO, "Cluster {0}: {1} localities, {2} drop categories", + clusterName, update.localityLbEndpointsMap.size(), + update.dropPolicies.size()); + } + Map localityLbEndpoints = + update.localityLbEndpointsMap; + List dropOverloads = update.dropPolicies; + List addresses = new ArrayList<>(); + Map> prioritizedLocalityWeights = new HashMap<>(); + List sortedPriorityNames = + generatePriorityNames(clusterName, localityLbEndpoints); + Map priorityLocalityWeightSums; + if (pickFirstWeightedShuffling) { + priorityLocalityWeightSums = new HashMap<>(sortedPriorityNames.size() * 2); + for (Locality locality : localityLbEndpoints.keySet()) { + LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); + String priorityName = localityPriorityNames.get(locality); + Long sum = priorityLocalityWeightSums.get(priorityName); + if (sum == null) 
{ + sum = 0L; } - if (clusterState.result == null) { // resource revoked or not exists - continue; + long weight = UnsignedInts.toLong(localityLbInfo.localityWeight()); + priorityLocalityWeightSums.put(priorityName, sum + weight); + } + } else { + priorityLocalityWeightSums = null; + } + + for (Locality locality : localityLbEndpoints.keySet()) { + LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); + String priorityName = localityPriorityNames.get(locality); + boolean discard = true; + // These sums _should_ fit in uint32, but XdsEndpointResource isn't actually verifying that + // is true today. Since we are using long to avoid signedness trouble, the math happens to + // still work if it turns out the sums exceed uint32. + long localityWeightSum = 0; + long endpointWeightSum = 0; + if (pickFirstWeightedShuffling) { + localityWeightSum = priorityLocalityWeightSums.get(priorityName); + for (LbEndpoint endpoint : localityLbInfo.endpoints()) { + if (endpoint.isHealthy()) { + endpointWeightSum += UnsignedInts.toLong(endpoint.loadBalancingWeight()); + } } - if (clusterState.isLeaf) { - if (instances.stream().map(inst -> inst.cluster).noneMatch(clusterState.name::equals)) { - DiscoveryMechanism instance; - if (clusterState.result.clusterType() == ClusterType.EDS) { - instance = DiscoveryMechanism.forEds( - clusterState.name, clusterState.result.edsServiceName(), - clusterState.result.lrsServerInfo(), - clusterState.result.maxConcurrentRequests(), - clusterState.result.upstreamTlsContext(), - clusterState.result.filterMetadata(), - clusterState.result.outlierDetection()); - } else { // logical DNS - instance = DiscoveryMechanism.forLogicalDns( - clusterState.name, clusterState.result.dnsHostName(), - clusterState.result.lrsServerInfo(), - clusterState.result.maxConcurrentRequests(), - clusterState.result.upstreamTlsContext(), - clusterState.result.filterMetadata()); + } + for (LbEndpoint endpoint : localityLbInfo.endpoints()) { + if 
(endpoint.isHealthy()) { + discard = false; + long weight; + if (pickFirstWeightedShuffling) { + // Combine locality and endpoint weights as defined by gRFC A113 + long localityWeight = fractionToFixedPoint( + UnsignedInts.toLong(localityLbInfo.localityWeight()), localityWeightSum); + long endpointWeight = fractionToFixedPoint( + UnsignedInts.toLong(endpoint.loadBalancingWeight()), endpointWeightSum); + weight = fixedPointMultiply(localityWeight, endpointWeight); + if (weight == 0) { + weight = 1; } - instances.add(instance); - } - } else { - if (clusterState.childClusterStates == null) { - continue; - } - // Do loop detection and break recursion if detected - List namesCausingLoops = identifyLoops(clusterState, parentClusters); - if (namesCausingLoops.isEmpty()) { - queue.addAll(clusterState.childClusterStates.values()); } else { - // Do cleanup - if (childLb != null) { - childLb.shutdown(); - childLb = null; + weight = localityLbInfo.localityWeight(); + if (endpoint.loadBalancingWeight() != 0) { + weight *= endpoint.loadBalancingWeight(); } - if (loopStatus != null) { - logger.log(XdsLogLevel.WARNING, - "Multiple loops in CDS config. 
Old msg: " + loopStatus.getDescription()); + } + + String localityName = localityName(locality); + Attributes attr = + endpoint.eag().getAttributes().toBuilder() + .set(io.grpc.xds.XdsAttributes.ATTR_LOCALITY, locality) + .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, localityName) + .set(io.grpc.xds.XdsAttributes.ATTR_LOCALITY_WEIGHT, + localityLbInfo.localityWeight()) + .set(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT, weight) + .set(XdsInternalAttributes.ATTR_ADDRESS_NAME, endpoint.hostname()) + .build(); + EquivalentAddressGroup eag; + if (discovery.isHttp11ProxyAvailable()) { + List rewrittenAddresses = new ArrayList<>(); + for (SocketAddress addr : endpoint.eag().getAddresses()) { + rewrittenAddresses.add(rewriteAddress( + addr, endpoint.endpointMetadata(), localityLbInfo.localityMetadata())); } - loopStatus = Status.UNAVAILABLE.withDescription(String.format( - "CDS error: circular aggregate clusters directly under %s for " - + "root cluster %s, named %s", - clusterState.name, root.name, namesCausingLoops)); + eag = new EquivalentAddressGroup(rewrittenAddresses, attr); + } else { + eag = new EquivalentAddressGroup(endpoint.eag().getAddresses(), attr); } + eag = AddressFilter.setPathFilter(eag, Arrays.asList(priorityName, localityName)); + addresses.add(eag); } } + if (discard) { + logger.log(XdsLogLevel.INFO, + "Discard locality {0} with 0 healthy endpoints", locality); + continue; + } + if (!prioritizedLocalityWeights.containsKey(priorityName)) { + prioritizedLocalityWeights.put(priorityName, new HashMap()); + } + prioritizedLocalityWeights.get(priorityName).put( + locality, localityLbInfo.localityWeight()); } + if (prioritizedLocalityWeights.isEmpty()) { + // Will still update the result, as if the cluster resource is revoked. 
+ logger.log(XdsLogLevel.INFO, + "Cluster {0} has no usable priority/locality/endpoint", clusterName); + } + sortedPriorityNames.retainAll(prioritizedLocalityWeights.keySet()); + Map priorityChildConfigs = + generatePriorityChildConfigs( + clusterName, discovery, lbConfig, lbRegistry, + prioritizedLocalityWeights, dropOverloads); + return StatusOr.fromValue(new ClusterResolutionResult(addresses, priorityChildConfigs, + sortedPriorityNames)); + } - if (loopStatus != null) { - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(loopStatus))); - return; + private SocketAddress rewriteAddress(SocketAddress addr, + ImmutableMap endpointMetadata, + ImmutableMap localityMetadata) { + if (!(addr instanceof InetSocketAddress)) { + return addr; } - if (instances.isEmpty()) { // none of non-aggregate clusters exists - if (childLb != null) { - childLb.shutdown(); - childLb = null; + SocketAddress proxyAddress; + try { + proxyAddress = (SocketAddress) endpointMetadata.get( + "envoy.http11_proxy_transport_socket.proxy_address"); + if (proxyAddress == null) { + proxyAddress = (SocketAddress) localityMetadata.get( + "envoy.http11_proxy_transport_socket.proxy_address"); } - Status unavailable = - Status.UNAVAILABLE.withDescription("CDS error: found 0 leaf (logical DNS or EDS) " - + "clusters for root cluster " + root.name); - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(unavailable))); - return; + } catch (ClassCastException e) { + return addr; } - // The LB policy config is provided in service_config.proto/JSON format. 
- NameResolver.ConfigOrError configOrError = - GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( - Arrays.asList(root.result.lbPolicyConfig()), lbRegistry); - if (configOrError.getError() != null) { - throw configOrError.getError().augmentDescription("Unable to parse the LB config") - .asRuntimeException(); + if (proxyAddress == null) { + return addr; } - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.unmodifiableList(instances), configOrError.getConfig()); - if (childLb == null) { - childLb = lbRegistry.getProvider(CLUSTER_RESOLVER_POLICY_NAME).newLoadBalancer(helper); - } - childLb.handleResolvedAddresses( - resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(config).build()); + return HttpConnectProxiedSocketAddress.newBuilder() + .setTargetAddress((InetSocketAddress) addr) + .setProxyAddress(proxyAddress) + .build(); } - /** - * Returns children that would cause loops and builds up the parentClusters map. - **/ - - private List identifyLoops(ClusterState clusterState, - Map> parentClusters) { - Set ancestors = new HashSet<>(); - ancestors.add(clusterState.name); - addAncestors(ancestors, clusterState, parentClusters); - - List namesCausingLoops = new ArrayList<>(); - for (ClusterState state : clusterState.childClusterStates.values()) { - if (ancestors.contains(state.name)) { - namesCausingLoops.add(state.name); + private List generatePriorityNames(String name, + Map localityLbEndpoints) { + TreeMap> todo = new TreeMap<>(); + for (Locality locality : localityLbEndpoints.keySet()) { + int priority = localityLbEndpoints.get(locality).priority(); + if (!todo.containsKey(priority)) { + todo.put(priority, new ArrayList<>()); } + todo.get(priority).add(locality); } + Map newNames = new HashMap<>(); + Set usedNames = new HashSet<>(); + List ret = new ArrayList<>(); + for (Integer priority: todo.keySet()) { + String foundName = ""; + for (Locality locality : todo.get(priority)) { + if (localityPriorityNames.containsKey(locality) 
+ && usedNames.add(localityPriorityNames.get(locality))) { + foundName = localityPriorityNames.get(locality); + break; + } + } + if ("".equals(foundName)) { + foundName = priorityName(name, priorityNameGenId++); + } + for (Locality locality : todo.get(priority)) { + newNames.put(locality, foundName); + } + ret.add(foundName); + } + localityPriorityNames = newNames; + return ret; + } + } - // Update parent map with entries from remaining children to clusterState - clusterState.childClusterStates.values().stream() - .filter(child -> !namesCausingLoops.contains(child.name)) - .forEach( - child -> parentClusters.computeIfAbsent(child, k -> new ArrayList<>()) - .add(clusterState)); - - return namesCausingLoops; + private static class ClusterResolutionResult { + // Endpoint addresses. + private final List addresses; + // Config (include load balancing policy/config) for each priority in the cluster. + private final Map priorityChildConfigs; + // List of priority names ordered in descending priorities. + private final List priorities; + + ClusterResolutionResult(List addresses, + Map configs, List priorities) { + this.addresses = addresses; + this.priorityChildConfigs = configs; + this.priorities = priorities; } + } - /** Recursively add all parents to the ancestors list. **/ - private void addAncestors(Set ancestors, ClusterState clusterState, - Map> parentClusters) { - List directParents = parentClusters.get(clusterState); - if (directParents != null) { - directParents.stream().map(c -> c.name).forEach(ancestors::add); - directParents.forEach(p -> addAncestors(ancestors, p, parentClusters)); + /** + * Generates configs to be used in the priority LB policy for priorities in a cluster. + * + *

priority LB -> cluster_impl LB (one per priority) -> (weighted_target LB + * -> round_robin / least_request_experimental (one per locality)) / ring_hash_experimental + */ + private static Map generatePriorityChildConfigs( + String clusterName, + CdsUpdate discovery, + Object endpointLbConfig, + LoadBalancerRegistry lbRegistry, + Map> prioritizedLocalityWeights, + List dropOverloads) { + Map configs = new HashMap<>(); + for (String priority : prioritizedLocalityWeights.keySet()) { + ClusterImplConfig clusterImplConfig = + new ClusterImplConfig( + clusterName, discovery.edsServiceName(), discovery.lrsServerInfo(), + discovery.maxConcurrentRequests(), dropOverloads, endpointLbConfig, + discovery.upstreamTlsContext(), discovery.filterMetadata(), + discovery.backendMetricPropagation()); + LoadBalancerProvider clusterImplLbProvider = + lbRegistry.getProvider(XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME); + Object priorityChildPolicy = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + clusterImplLbProvider, clusterImplConfig); + + // If outlier detection has been configured we wrap the child policy in the outlier detection + // load balancer. 
+ if (discovery.outlierDetection() != null) { + LoadBalancerProvider outlierDetectionProvider = lbRegistry.getProvider( + "outlier_detection_experimental"); + priorityChildPolicy = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + outlierDetectionProvider, + buildOutlierDetectionLbConfig(discovery.outlierDetection(), priorityChildPolicy)); } + + boolean isEds = discovery.clusterType() == ClusterType.EDS; + PriorityChildConfig priorityChildConfig = + new PriorityChildConfig(priorityChildPolicy, isEds /* ignoreReresolution */); + configs.put(priority, priorityChildConfig); } + return configs; + } - private void handleClusterDiscoveryError(Status error) { - if (childLb != null) { - childLb.handleNameResolutionError(error); - } else { - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); - } + /** + * Converts {@link OutlierDetection} that represents the xDS configuration to {@link + * OutlierDetectionLoadBalancerConfig} that the {@link io.grpc.util.OutlierDetectionLoadBalancer} + * understands. 
+ */ + private static OutlierDetectionLoadBalancerConfig buildOutlierDetectionLbConfig( + OutlierDetection outlierDetection, Object childConfig) { + OutlierDetectionLoadBalancerConfig.Builder configBuilder + = new OutlierDetectionLoadBalancerConfig.Builder(); + + configBuilder.setChildConfig(childConfig); + + if (outlierDetection.intervalNanos() != null) { + configBuilder.setIntervalNanos(outlierDetection.intervalNanos()); + } + if (outlierDetection.baseEjectionTimeNanos() != null) { + configBuilder.setBaseEjectionTimeNanos(outlierDetection.baseEjectionTimeNanos()); + } + if (outlierDetection.maxEjectionTimeNanos() != null) { + configBuilder.setMaxEjectionTimeNanos(outlierDetection.maxEjectionTimeNanos()); + } + if (outlierDetection.maxEjectionPercent() != null) { + configBuilder.setMaxEjectionPercent(outlierDetection.maxEjectionPercent()); } - private final class ClusterState implements ResourceWatcher { - private final String name; - @Nullable - private Map childClusterStates; - @Nullable - private CdsUpdate result; - // Following fields are effectively final. 
- private boolean isLeaf; - private boolean discovered; - private boolean shutdown; - - private ClusterState(String name) { - this.name = name; - } + SuccessRateEjection successRate = outlierDetection.successRateEjection(); + if (successRate != null) { + OutlierDetectionLoadBalancerConfig.SuccessRateEjection.Builder + successRateConfigBuilder = new OutlierDetectionLoadBalancerConfig + .SuccessRateEjection.Builder(); - private void start() { - shutdown = false; - xdsClient.watchXdsResource(XdsClusterResource.getInstance(), name, this, syncContext); + if (successRate.stdevFactor() != null) { + successRateConfigBuilder.setStdevFactor(successRate.stdevFactor()); } - - void shutdown() { - shutdown = true; - xdsClient.cancelXdsResourceWatch(XdsClusterResource.getInstance(), name, this); - if (childClusterStates != null) { - // recursively shut down all descendants - childClusterStates.values().stream() - .filter(state -> !state.shutdown) - .forEach(ClusterState::shutdown); - } + if (successRate.enforcementPercentage() != null) { + successRateConfigBuilder.setEnforcementPercentage(successRate.enforcementPercentage()); } - - @Override - public void onError(Status error) { - Status status = Status.UNAVAILABLE - .withDescription( - String.format("Unable to load CDS %s. xDS server returned: %s: %s", - name, error.getCode(), error.getDescription())) - .withCause(error.getCause()); - if (shutdown) { - return; - } - // All watchers should receive the same error, so we only propagate it once. 
- if (ClusterState.this == root) { - handleClusterDiscoveryError(status); - } + if (successRate.minimumHosts() != null) { + successRateConfigBuilder.setMinimumHosts(successRate.minimumHosts()); } - - @Override - public void onResourceDoesNotExist(String resourceName) { - if (shutdown) { - return; - } - discovered = true; - result = null; - if (childClusterStates != null) { - for (ClusterState state : childClusterStates.values()) { - state.shutdown(); - } - childClusterStates = null; - } - handleClusterDiscovered(); + if (successRate.requestVolume() != null) { + successRateConfigBuilder.setRequestVolume(successRate.requestVolume()); } - @Override - public void onChanged(final CdsUpdate update) { - if (shutdown) { - return; - } - logger.log(XdsLogLevel.DEBUG, "Received cluster update {0}", update); - discovered = true; - result = update; - if (update.clusterType() == ClusterType.AGGREGATE) { - isLeaf = false; - logger.log(XdsLogLevel.INFO, "Aggregate cluster {0}, underlying clusters: {1}", - update.clusterName(), update.prioritizedClusterNames()); - Map newChildStates = new LinkedHashMap<>(); - for (String cluster : update.prioritizedClusterNames()) { - if (newChildStates.containsKey(cluster)) { - logger.log(XdsLogLevel.WARNING, - String.format("duplicate cluster name %s in aggregate %s is being ignored", - cluster, update.clusterName())); - continue; - } - if (childClusterStates == null || !childClusterStates.containsKey(cluster)) { - ClusterState childState; - if (clusterStates.containsKey(cluster)) { - childState = clusterStates.get(cluster); - if (childState.shutdown) { - childState.start(); - } - } else { - childState = new ClusterState(cluster); - clusterStates.put(cluster, childState); - childState.start(); - } - newChildStates.put(cluster, childState); - } else { - newChildStates.put(cluster, childClusterStates.remove(cluster)); - } - } - if (childClusterStates != null) { // stop subscribing to revoked child clusters - for (ClusterState watcher : 
childClusterStates.values()) { - watcher.shutdown(); - } - } - childClusterStates = newChildStates; - } else if (update.clusterType() == ClusterType.EDS) { - isLeaf = true; - logger.log(XdsLogLevel.INFO, "EDS cluster {0}, edsServiceName: {1}", - update.clusterName(), update.edsServiceName()); - } else { // logical DNS - isLeaf = true; - logger.log(XdsLogLevel.INFO, "Logical DNS cluster {0}", update.clusterName()); - } - handleClusterDiscovered(); + configBuilder.setSuccessRateEjection(successRateConfigBuilder.build()); + } + + FailurePercentageEjection failurePercentage = outlierDetection.failurePercentageEjection(); + if (failurePercentage != null) { + OutlierDetectionLoadBalancerConfig.FailurePercentageEjection.Builder + failurePercentageConfigBuilder = new OutlierDetectionLoadBalancerConfig + .FailurePercentageEjection.Builder(); + + if (failurePercentage.threshold() != null) { + failurePercentageConfigBuilder.setThreshold(failurePercentage.threshold()); + } + if (failurePercentage.enforcementPercentage() != null) { + failurePercentageConfigBuilder.setEnforcementPercentage( + failurePercentage.enforcementPercentage()); + } + if (failurePercentage.minimumHosts() != null) { + failurePercentageConfigBuilder.setMinimumHosts(failurePercentage.minimumHosts()); + } + if (failurePercentage.requestVolume() != null) { + failurePercentageConfigBuilder.setRequestVolume(failurePercentage.requestVolume()); } + configBuilder.setFailurePercentageEjection(failurePercentageConfigBuilder.build()); } + + return configBuilder.build(); } } diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancerProvider.java index 01bd2ab27f6..875af9089ed 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancerProvider.java @@ -23,6 +23,7 @@ import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancerProvider; +import 
io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; import io.grpc.internal.JsonUtil; @@ -36,8 +37,6 @@ @Internal public class CdsLoadBalancerProvider extends LoadBalancerProvider { - private static final String CLUSTER_KEY = "cluster"; - @Override public boolean isAvailable() { return true; @@ -53,9 +52,24 @@ public String getPolicyName() { return XdsLbPolicies.CDS_POLICY_NAME; } + private final LoadBalancerRegistry loadBalancerRegistry; + + public CdsLoadBalancerProvider() { + this.loadBalancerRegistry = null; + } + + public CdsLoadBalancerProvider(LoadBalancerRegistry loadBalancerRegistry) { + this.loadBalancerRegistry = loadBalancerRegistry; + } + @Override public LoadBalancer newLoadBalancer(Helper helper) { - return new CdsLoadBalancer2(helper); + LoadBalancerRegistry loadBalancerRegistry = this.loadBalancerRegistry; + if (loadBalancerRegistry == null) { + loadBalancerRegistry = LoadBalancerRegistry.getDefaultRegistry(); + } + + return new CdsLoadBalancer2(helper, loadBalancerRegistry); } @Override @@ -70,9 +84,12 @@ public ConfigOrError parseLoadBalancingPolicyConfig( */ static ConfigOrError parseLoadBalancingConfigPolicy(Map rawLoadBalancingPolicyConfig) { try { - String cluster = - JsonUtil.getString(rawLoadBalancingPolicyConfig, CLUSTER_KEY); - return ConfigOrError.fromConfig(new CdsConfig(cluster)); + String cluster = JsonUtil.getString(rawLoadBalancingPolicyConfig, "cluster"); + Boolean isDynamic = JsonUtil.getBoolean(rawLoadBalancingPolicyConfig, "is_dynamic"); + if (isDynamic == null) { + isDynamic = Boolean.FALSE; + } + return ConfigOrError.fromConfig(new CdsConfig(cluster, isDynamic)); } catch (RuntimeException e) { return ConfigOrError.fromError( Status.UNAVAILABLE.withCause(e).withDescription( @@ -89,15 +106,28 @@ static final class CdsConfig { * Name of cluster to query CDS for. 
*/ final String name; + /** + * Whether this cluster was dynamically chosen, so the XdsDependencyManager may be unaware of + * it without an explicit cluster subscription. + */ + final boolean isDynamic; CdsConfig(String name) { + this(name, false); + } + + CdsConfig(String name, boolean isDynamic) { checkArgument(name != null && !name.isEmpty(), "name is null or empty"); this.name = name; + this.isDynamic = isDynamic; } @Override public String toString() { - return MoreObjects.toStringHelper(this).add("name", name).toString(); + return MoreObjects.toStringHelper(this) + .add("name", name) + .add("isDynamic", isDynamic) + .toString(); } } } diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 0ea2c7dd75f..64105144240 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.client.LoadStatsManager2.isEnabledOrcaLrsPropagation; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; @@ -32,9 +33,10 @@ import io.grpc.InternalLogId; import io.grpc.LoadBalancer; import io.grpc.Metadata; +import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.internal.ForwardingClientStreamTracer; -import io.grpc.internal.ObjectPool; +import io.grpc.internal.GrpcUtil; import io.grpc.services.MetricReport; import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.ForwardingSubchannel; @@ -44,6 +46,7 @@ import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; +import io.grpc.xds.client.BackendMetricPropagation; import io.grpc.xds.client.Bootstrapper.ServerInfo; import 
io.grpc.xds.client.LoadStatsManager2.ClusterDropStats; import io.grpc.xds.client.LoadStatsManager2.ClusterLocalityStats; @@ -51,12 +54,15 @@ import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsLogger; import io.grpc.xds.client.XdsLogger.XdsLogLevel; +import io.grpc.xds.internal.XdsInternalAttributes; +import io.grpc.xds.internal.security.SecurityProtocolNegotiators; import io.grpc.xds.internal.security.SslContextProviderSupplier; import io.grpc.xds.orca.OrcaPerRequestUtil; import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicLong; @@ -81,6 +87,9 @@ final class ClusterImplLoadBalancer extends LoadBalancer { private static final Attributes.Key> ATTR_CLUSTER_LOCALITY = Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocality"); + @VisibleForTesting + static final Attributes.Key ATTR_SUBCHANNEL_ADDRESS_NAME = + Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.addressName"); private final XdsLogger logger; private final Helper helper; @@ -89,7 +98,6 @@ final class ClusterImplLoadBalancer extends LoadBalancer { private String cluster; @Nullable private String edsServiceName; - private ObjectPool xdsClientPool; private XdsClient xdsClient; private CallCounterProvider callCounterProvider; private ClusterDropStats dropStats; @@ -112,13 +120,11 @@ final class ClusterImplLoadBalancer extends LoadBalancer { public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); Attributes attributes = resolvedAddresses.getAttributes(); - if (xdsClientPool == null) { - xdsClientPool = attributes.get(InternalXdsAttributes.XDS_CLIENT_POOL); - assert xdsClientPool != null; - xdsClient = xdsClientPool.getObject(); + if (xdsClient == null) { 
+ xdsClient = checkNotNull(attributes.get(io.grpc.xds.XdsAttributes.XDS_CLIENT), "xdsClient"); } if (callCounterProvider == null) { - callCounterProvider = attributes.get(InternalXdsAttributes.CALL_COUNTER_PROVIDER); + callCounterProvider = attributes.get(io.grpc.xds.XdsAttributes.CALL_COUNTER_PROVIDER); } ClusterImplConfig config = @@ -144,13 +150,15 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { childLbHelper.updateMaxConcurrentRequests(config.maxConcurrentRequests); childLbHelper.updateSslContextProviderSupplier(config.tlsContext); childLbHelper.updateFilterMetadata(config.filterMetadata); + childLbHelper.updateBackendMetricPropagation(config.backendMetricPropagation); - childSwitchLb.handleResolvedAddresses( + return childSwitchLb.acceptResolvedAddresses( resolvedAddresses.toBuilder() - .setAttributes(attributes) + .setAttributes(attributes.toBuilder() + .set(NameResolver.ATTR_BACKEND_SERVICE, cluster) + .build()) .setLoadBalancingPolicyConfig(config.childConfig) .build()); - return Status.OK; } @Override @@ -163,6 +171,13 @@ public void handleNameResolutionError(Status error) { } } + @Override + public void requestConnection() { + if (childSwitchLb != null) { + childSwitchLb.requestConnection(); + } + } + @Override public void shutdown() { if (dropStats != null) { @@ -175,9 +190,7 @@ public void shutdown() { childLbHelper = null; } } - if (xdsClient != null) { - xdsClient = xdsClientPool.returnObject(xdsClient); - } + xdsClient = null; } /** @@ -195,6 +208,8 @@ private final class ClusterImplLbHelper extends ForwardingLoadBalancerHelper { private Map filterMetadata = ImmutableMap.of(); @Nullable private final ServerInfo lrsServerInfo; + @Nullable + private BackendMetricPropagation backendMetricPropagation; private ClusterImplLbHelper(AtomicLong inFlights, @Nullable ServerInfo lrsServerInfo) { this.inFlights = checkNotNull(inFlights, "inFlights"); @@ -224,47 +239,67 @@ public Subchannel createSubchannel(CreateSubchannelArgs 
args) { args.getAddresses().get(0).getAttributes()); AtomicReference localityAtomicReference = new AtomicReference<>( clusterLocality); - Attributes attrs = args.getAttributes().toBuilder() - .set(ATTR_CLUSTER_LOCALITY, localityAtomicReference) - .build(); - args = args.toBuilder().setAddresses(addresses).setAttributes(attrs).build(); + Attributes.Builder attrsBuilder = args.getAttributes().toBuilder() + .set(ATTR_CLUSTER_LOCALITY, localityAtomicReference); + if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)) { + String hostname = args.getAddresses().get(0).getAttributes() + .get(XdsInternalAttributes.ATTR_ADDRESS_NAME); + if (hostname != null) { + attrsBuilder.set(ATTR_SUBCHANNEL_ADDRESS_NAME, hostname); + } + } + args = args.toBuilder().setAddresses(addresses).setAttributes(attrsBuilder.build()).build(); final Subchannel subchannel = delegate().createSubchannel(args); - return new ForwardingSubchannel() { - @Override - public void start(SubchannelStateListener listener) { - delegate().start(new SubchannelStateListener() { - @Override - public void onSubchannelState(ConnectivityStateInfo newState) { - if (newState.getState().equals(ConnectivityState.READY)) { - // Get locality based on the connected address attributes - ClusterLocality updatedClusterLocality = createClusterLocalityFromAttributes( - subchannel.getConnectedAddressAttributes()); - ClusterLocality oldClusterLocality = localityAtomicReference - .getAndSet(updatedClusterLocality); - oldClusterLocality.release(); + return new ClusterImplSubchannel(subchannel, localityAtomicReference); + } + + private final class ClusterImplSubchannel extends ForwardingSubchannel { + private final Subchannel delegate; + private final AtomicReference localityAtomicReference; + + private ClusterImplSubchannel( + Subchannel delegate, AtomicReference localityAtomicReference) { + this.delegate = delegate; + this.localityAtomicReference = localityAtomicReference; + } + + @Override + public void 
start(SubchannelStateListener listener) { + delegate().start( + new SubchannelStateListener() { + @Override + public void onSubchannelState(ConnectivityStateInfo newState) { + // Do nothing if LB has been shutdown + if (xdsClient != null && newState.getState().equals(ConnectivityState.READY)) { + // Get locality based on the connected address attributes + ClusterLocality updatedClusterLocality = + createClusterLocalityFromAttributes( + delegate.getConnectedAddressAttributes()); + ClusterLocality oldClusterLocality = + localityAtomicReference.getAndSet(updatedClusterLocality); + oldClusterLocality.release(); + } + listener.onSubchannelState(newState); } - listener.onSubchannelState(newState); - } - }); - } + }); + } - @Override - public void shutdown() { - localityAtomicReference.get().release(); - delegate().shutdown(); - } + @Override + public void shutdown() { + localityAtomicReference.get().release(); + delegate().shutdown(); + } - @Override - public void updateAddresses(List addresses) { - delegate().updateAddresses(withAdditionalAttributes(addresses)); - } + @Override + public void updateAddresses(List addresses) { + delegate().updateAddresses(withAdditionalAttributes(addresses)); + } - @Override - protected Subchannel delegate() { - return subchannel; - } - }; + @Override + protected Subchannel delegate() { + return delegate; + } } private List withAdditionalAttributes( @@ -272,10 +307,10 @@ private List withAdditionalAttributes( List newAddresses = new ArrayList<>(); for (EquivalentAddressGroup eag : addresses) { Attributes.Builder attrBuilder = eag.getAttributes().toBuilder().set( - InternalXdsAttributes.ATTR_CLUSTER_NAME, cluster); + io.grpc.xds.XdsAttributes.ATTR_CLUSTER_NAME, cluster); if (sslContextProviderSupplier != null) { attrBuilder.set( - InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, sslContextProviderSupplier); } newAddresses.add(new 
EquivalentAddressGroup(eag.getAddresses(), attrBuilder.build())); @@ -284,8 +319,8 @@ private List withAdditionalAttributes( } private ClusterLocality createClusterLocalityFromAttributes(Attributes addressAttributes) { - Locality locality = addressAttributes.get(InternalXdsAttributes.ATTR_LOCALITY); - String localityName = addressAttributes.get(InternalXdsAttributes.ATTR_LOCALITY_NAME); + Locality locality = addressAttributes.get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY); + String localityName = addressAttributes.get(EquivalentAddressGroup.ATTR_LOCALITY_NAME); // Endpoint addresses resolved by ClusterResolverLoadBalancer should always contain // attributes with its locality, including endpoints in LOGICAL_DNS clusters. @@ -300,7 +335,7 @@ private ClusterLocality createClusterLocalityFromAttributes(Attributes addressAt (lrsServerInfo == null) ? null : xdsClient.addClusterLocalityStats(lrsServerInfo, cluster, - edsServiceName, locality); + edsServiceName, locality, backendMetricPropagation); return new ClusterLocality(localityStats, localityName); } @@ -350,6 +385,11 @@ private void updateFilterMetadata(Map filterMetadata) { this.filterMetadata = ImmutableMap.copyOf(filterMetadata); } + private void updateBackendMetricPropagation( + @Nullable BackendMetricPropagation backendMetricPropagation) { + this.backendMetricPropagation = backendMetricPropagation; + } + private class RequestLimitingSubchannelPicker extends SubchannelPicker { private final SubchannelPicker delegate; private final List dropPolicies; @@ -369,6 +409,7 @@ private RequestLimitingSubchannelPicker(SubchannelPicker delegate, public PickResult pickSubchannel(PickSubchannelArgs args) { args.getCallOptions().getOption(ClusterImplLoadBalancerProvider.FILTER_METADATA_CONSUMER) .accept(filterMetadata); + args.getPickDetailsConsumer().addOptionalLabel("grpc.lb.backend_service", cluster); for (DropOverload dropOverload : dropPolicies) { int rand = random.nextInt(1_000_000); if (rand < 
dropOverload.dropsPerMillion()) { @@ -381,15 +422,21 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { Status.UNAVAILABLE.withDescription("Dropped: " + dropOverload.category())); } } - final PickResult result = delegate.pickSubchannel(args); + PickResult result = delegate.pickSubchannel(args); if (result.getStatus().isOk() && result.getSubchannel() != null) { + Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof ClusterImplLbHelper.ClusterImplSubchannel) { + subchannel = ((ClusterImplLbHelper.ClusterImplSubchannel) subchannel).delegate(); + result = result.copyWithSubchannel(subchannel); + } if (enableCircuitBreaking) { if (inFlights.get() >= maxConcurrentRequests) { if (dropStats != null) { dropStats.recordDroppedRequest(); } return PickResult.withDrop(Status.UNAVAILABLE.withDescription( - "Cluster max concurrent requests limit exceeded")); + String.format(Locale.US, "Cluster max concurrent requests limit of %d exceeded", + maxConcurrentRequests))); } } final AtomicReference clusterLocality = @@ -407,9 +454,15 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { stats, inFlights, result.getStreamTracerFactory()); ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance() .newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats)); - return PickResult.withSubchannel(result.getSubchannel(), orcaTracerFactory); + result = result.copyWithStreamTracerFactory(orcaTracerFactory); } } + if (args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY) != null + && args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY)) { + result = PickResult.withSubchannel(result.getSubchannel(), + result.getStreamTracerFactory(), + result.getSubchannel().getAttributes().get(ATTR_SUBCHANNEL_ADDRESS_NAME)); + } } return result; } @@ -475,11 +528,19 @@ private OrcaPerRpcListener(ClusterLocalityStats stats) { } /** - * Copies {@link MetricReport#getNamedMetrics()} to {@link 
ClusterLocalityStats} such that it is - * included in the snapshot for the LRS report sent to the LRS server. + * Copies ORCA metrics from {@link MetricReport} to {@link ClusterLocalityStats} + * such that they are included in the snapshot for the LRS report sent to the LRS server. + * This includes both top-level metrics (CPU, memory, application utilization) and named + * metrics, filtered according to the backend metric propagation configuration. */ @Override public void onLoadReport(MetricReport report) { + if (isEnabledOrcaLrsPropagation) { + stats.recordTopLevelMetrics( + report.getCpuUtilization(), + report.getMemoryUtilization(), + report.getApplicationUtilization()); + } stats.recordBackendLoadMetricStats(report.getNamedMetrics()); } } diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancerProvider.java index 4c9c14ba5f5..f369c3b99b4 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancerProvider.java @@ -31,6 +31,7 @@ import io.grpc.Status; import io.grpc.xds.Endpoints.DropOverload; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.client.BackendMetricPropagation; import io.grpc.xds.client.Bootstrapper.ServerInfo; import java.util.ArrayList; import java.util.Collections; @@ -98,11 +99,14 @@ static final class ClusterImplConfig { // Provides the direct child policy and its config. 
final Object childConfig; final Map filterMetadata; + @Nullable + final BackendMetricPropagation backendMetricPropagation; ClusterImplConfig(String cluster, @Nullable String edsServiceName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, List dropCategories, Object childConfig, - @Nullable UpstreamTlsContext tlsContext, Map filterMetadata) { + @Nullable UpstreamTlsContext tlsContext, Map filterMetadata, + @Nullable BackendMetricPropagation backendMetricPropagation) { this.cluster = checkNotNull(cluster, "cluster"); this.edsServiceName = edsServiceName; this.lrsServerInfo = lrsServerInfo; @@ -112,6 +116,7 @@ static final class ClusterImplConfig { this.dropCategories = Collections.unmodifiableList( new ArrayList<>(checkNotNull(dropCategories, "dropCategories"))); this.childConfig = checkNotNull(childConfig, "childConfig"); + this.backendMetricPropagation = backendMetricPropagation; } @Override diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java index c175b847c63..22b5aaa7d73 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java @@ -77,55 +77,41 @@ protected ChildLbState createChildLbState(Object key) { @Override protected Map createChildAddressesMap( ResolvedAddresses resolvedAddresses) { + lastResolvedAddresses = resolvedAddresses; + ClusterManagerConfig config = (ClusterManagerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - Map childAddresses = new HashMap<>(); - if (config != null) { - for (Map.Entry childPolicy : config.childPolicies.entrySet()) { - ResolvedAddresses addresses = resolvedAddresses.toBuilder() - .setLoadBalancingPolicyConfig(childPolicy.getValue()) - .build(); - childAddresses.put(childPolicy.getKey(), addresses); - } - } logger.log( XdsLogLevel.INFO, - "Received cluster_manager lb config: child names={0}", childAddresses.keySet()); - 
return childAddresses; - } + "Received cluster_manager lb config: child names={0}", config.childPolicies.keySet()); + Map childAddresses = new HashMap<>(); - /** - * This is like the parent except that it doesn't shutdown the removed children since we want that - * to be done by the timer. - */ - @Override - public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - if (lastResolvedAddresses != null) { - // Handle deactivated children - ClusterManagerConfig config = (ClusterManagerConfig) - resolvedAddresses.getLoadBalancingPolicyConfig(); - ClusterManagerConfig lastConfig = (ClusterManagerConfig) - lastResolvedAddresses.getLoadBalancingPolicyConfig(); - Map adjChildPolicies = new HashMap<>(config.childPolicies); - for (Entry entry : lastConfig.childPolicies.entrySet()) { - ClusterManagerLbState state = (ClusterManagerLbState) getChildLbState(entry.getKey()); - if (adjChildPolicies.containsKey(entry.getKey())) { - if (state.deletionTimer != null) { - state.reactivateChild(); - } - } else if (state != null) { - adjChildPolicies.put(entry.getKey(), entry.getValue()); - if (state.deletionTimer == null) { - state.deactivateChild(); - } + // Reactivate children with config; deactivate children without config + for (ChildLbState rawState : getChildLbStates()) { + ClusterManagerLbState state = (ClusterManagerLbState) rawState; + if (config.childPolicies.containsKey(state.getKey())) { + // Active child + if (state.deletionTimer != null) { + state.reactivateChild(); + } + } else { + // Inactive child + if (state.deletionTimer == null) { + state.deactivateChild(); + } + if (state.deletionTimer.isPending()) { + childAddresses.put(state.getKey(), null); // Preserve child, without config update } } - config = new ClusterManagerConfig(adjChildPolicies); - resolvedAddresses = - resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(config).build(); } - lastResolvedAddresses = resolvedAddresses; - return super.acceptResolvedAddresses(resolvedAddresses); + + 
for (Map.Entry childPolicy : config.childPolicies.entrySet()) { + ResolvedAddresses addresses = resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(childPolicy.getValue()) + .build(); + childAddresses.put(childPolicy.getKey(), addresses); + } + return childAddresses; } /** @@ -232,14 +218,6 @@ class DeletionTask implements Runnable { @Override public void run() { - ClusterManagerConfig config = (ClusterManagerConfig) - lastResolvedAddresses.getLoadBalancingPolicyConfig(); - Map childPolicies = new HashMap<>(config.childPolicies); - Object removed = childPolicies.remove(getKey()); - assert removed != null; - config = new ClusterManagerConfig(childPolicies); - lastResolvedAddresses = - lastResolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(config).build(); acceptResolvedAddresses(lastResolvedAddresses); } } diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java deleted file mode 100644 index 875a6c45020..00000000000 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java +++ /dev/null @@ -1,865 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.grpc.xds; - -import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; -import com.google.protobuf.Struct; -import io.grpc.Attributes; -import io.grpc.EquivalentAddressGroup; -import io.grpc.InternalLogId; -import io.grpc.LoadBalancer; -import io.grpc.LoadBalancerProvider; -import io.grpc.LoadBalancerRegistry; -import io.grpc.NameResolver; -import io.grpc.NameResolver.ResolutionResult; -import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.SynchronizationContext.ScheduledHandle; -import io.grpc.internal.BackoffPolicy; -import io.grpc.internal.ExponentialBackoffPolicy; -import io.grpc.internal.ObjectPool; -import io.grpc.util.ForwardingLoadBalancerHelper; -import io.grpc.util.GracefulSwitchLoadBalancer; -import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; -import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig.DiscoveryMechanism; -import io.grpc.xds.Endpoints.DropOverload; -import io.grpc.xds.Endpoints.LbEndpoint; -import io.grpc.xds.Endpoints.LocalityLbEndpoints; -import io.grpc.xds.EnvoyServerProtoData.FailurePercentageEjection; -import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; -import io.grpc.xds.EnvoyServerProtoData.SuccessRateEjection; -import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig; -import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig; -import io.grpc.xds.XdsEndpointResource.EdsUpdate; -import io.grpc.xds.client.Bootstrapper.ServerInfo; -import 
io.grpc.xds.client.Locality; -import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsClient.ResourceWatcher; -import io.grpc.xds.client.XdsLogger; -import io.grpc.xds.client.XdsLogger.XdsLogLevel; -import java.net.URI; -import java.net.URISyntaxException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.TreeMap; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import javax.annotation.Nullable; - -/** - * Load balancer for cluster_resolver_experimental LB policy. This LB policy is the child LB policy - * of the cds_experimental LB policy and the parent LB policy of the priority_experimental LB - * policy in the xDS load balancing hierarchy. This policy resolves endpoints of non-aggregate - * clusters (e.g., EDS or Logical DNS) and groups endpoints in priorities and localities to be - * used in the downstream LB policies for fine-grained load balancing purposes. - */ -final class ClusterResolverLoadBalancer extends LoadBalancer { - // DNS-resolved endpoints do not have the definition of the locality it belongs to, just hardcode - // to an empty locality. 
- private static final Locality LOGICAL_DNS_CLUSTER_LOCALITY = Locality.create("", "", ""); - private final XdsLogger logger; - private final SynchronizationContext syncContext; - private final ScheduledExecutorService timeService; - private final LoadBalancerRegistry lbRegistry; - private final BackoffPolicy.Provider backoffPolicyProvider; - private final GracefulSwitchLoadBalancer delegate; - private ObjectPool xdsClientPool; - private XdsClient xdsClient; - private ClusterResolverConfig config; - - ClusterResolverLoadBalancer(Helper helper) { - this(helper, LoadBalancerRegistry.getDefaultRegistry(), - new ExponentialBackoffPolicy.Provider()); - } - - @VisibleForTesting - ClusterResolverLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry, - BackoffPolicy.Provider backoffPolicyProvider) { - this.lbRegistry = checkNotNull(lbRegistry, "lbRegistry"); - this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); - this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); - this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); - delegate = new GracefulSwitchLoadBalancer(helper); - logger = XdsLogger.withLogId( - InternalLogId.allocate("cluster-resolver-lb", helper.getAuthority())); - logger.log(XdsLogLevel.INFO, "Created"); - } - - @Override - public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); - if (xdsClientPool == null) { - xdsClientPool = resolvedAddresses.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL); - xdsClient = xdsClientPool.getObject(); - } - ClusterResolverConfig config = - (ClusterResolverConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - if (!Objects.equals(this.config, config)) { - logger.log(XdsLogLevel.DEBUG, "Config: {0}", config); - this.config = config; - Object gracefulConfig = 
GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( - new ClusterResolverLbStateFactory(), config); - delegate.handleResolvedAddresses( - resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(gracefulConfig).build()); - } - return Status.OK; - } - - @Override - public void handleNameResolutionError(Status error) { - logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error); - delegate.handleNameResolutionError(error); - } - - @Override - public void shutdown() { - logger.log(XdsLogLevel.INFO, "Shutdown"); - delegate.shutdown(); - if (xdsClientPool != null) { - xdsClientPool.returnObject(xdsClient); - } - } - - private final class ClusterResolverLbStateFactory extends LoadBalancer.Factory { - @Override - public LoadBalancer newLoadBalancer(Helper helper) { - return new ClusterResolverLbState(helper); - } - } - - /** - * The state of a cluster_resolver LB working session. A new instance is created whenever - * the cluster_resolver LB receives a new config. The old instance is replaced when the - * new one is ready to handle new RPCs. 
- */ - private final class ClusterResolverLbState extends LoadBalancer { - private final Helper helper; - private final List clusters = new ArrayList<>(); - private final Map clusterStates = new HashMap<>(); - private Object endpointLbConfig; - private ResolvedAddresses resolvedAddresses; - private LoadBalancer childLb; - - ClusterResolverLbState(Helper helper) { - this.helper = new RefreshableHelper(checkNotNull(helper, "helper")); - logger.log(XdsLogLevel.DEBUG, "New ClusterResolverLbState"); - } - - @Override - public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - this.resolvedAddresses = resolvedAddresses; - ClusterResolverConfig config = - (ClusterResolverConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - endpointLbConfig = config.lbConfig; - for (DiscoveryMechanism instance : config.discoveryMechanisms) { - clusters.add(instance.cluster); - ClusterState state; - if (instance.type == DiscoveryMechanism.Type.EDS) { - state = new EdsClusterState(instance.cluster, instance.edsServiceName, - instance.lrsServerInfo, instance.maxConcurrentRequests, instance.tlsContext, - instance.filterMetadata, instance.outlierDetection); - } else { // logical DNS - state = new LogicalDnsClusterState(instance.cluster, instance.dnsHostName, - instance.lrsServerInfo, instance.maxConcurrentRequests, instance.tlsContext, - instance.filterMetadata); - } - clusterStates.put(instance.cluster, state); - state.start(); - } - return Status.OK; - } - - @Override - public void handleNameResolutionError(Status error) { - if (childLb != null) { - childLb.handleNameResolutionError(error); - } else { - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); - } - } - - @Override - public void shutdown() { - for (ClusterState state : clusterStates.values()) { - state.shutdown(); - } - if (childLb != null) { - childLb.shutdown(); - } - } - - private void handleEndpointResourceUpdate() { - List addresses = new 
ArrayList<>(); - Map priorityChildConfigs = new HashMap<>(); - List priorities = new ArrayList<>(); // totally ordered priority list - - Status endpointNotFound = Status.OK; - for (String cluster : clusters) { - ClusterState state = clusterStates.get(cluster); - // Propagate endpoints to the child LB policy only after all clusters have been resolved. - if (!state.resolved && state.status.isOk()) { - return; - } - if (state.result != null) { - addresses.addAll(state.result.addresses); - priorityChildConfigs.putAll(state.result.priorityChildConfigs); - priorities.addAll(state.result.priorities); - } else { - endpointNotFound = state.status; - } - } - if (addresses.isEmpty()) { - if (endpointNotFound.isOk()) { - endpointNotFound = Status.UNAVAILABLE.withDescription( - "No usable endpoint from cluster(s): " + clusters); - } else { - endpointNotFound = - Status.UNAVAILABLE.withCause(endpointNotFound.getCause()) - .withDescription(endpointNotFound.getDescription()); - } - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(endpointNotFound))); - if (childLb != null) { - childLb.shutdown(); - childLb = null; - } - return; - } - PriorityLbConfig childConfig = - new PriorityLbConfig(Collections.unmodifiableMap(priorityChildConfigs), - Collections.unmodifiableList(priorities)); - if (childLb == null) { - childLb = lbRegistry.getProvider(PRIORITY_POLICY_NAME).newLoadBalancer(helper); - } - childLb.handleResolvedAddresses( - resolvedAddresses.toBuilder() - .setLoadBalancingPolicyConfig(childConfig) - .setAddresses(Collections.unmodifiableList(addresses)) - .build()); - } - - private void handleEndpointResolutionError() { - boolean allInError = true; - Status error = null; - for (String cluster : clusters) { - ClusterState state = clusterStates.get(cluster); - if (state.status.isOk()) { - allInError = false; - } else { - error = state.status; - } - } - if (allInError) { - if (childLb != null) { - childLb.handleNameResolutionError(error); 
- } else { - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); - } - } - } - - /** - * Wires re-resolution requests from downstream LB policies with DNS resolver. - */ - private final class RefreshableHelper extends ForwardingLoadBalancerHelper { - private final Helper delegate; - - private RefreshableHelper(Helper delegate) { - this.delegate = checkNotNull(delegate, "delegate"); - } - - @Override - public void refreshNameResolution() { - for (ClusterState state : clusterStates.values()) { - if (state instanceof LogicalDnsClusterState) { - ((LogicalDnsClusterState) state).refresh(); - } - } - } - - @Override - protected Helper delegate() { - return delegate; - } - } - - /** - * Resolution state of an underlying cluster. - */ - private abstract class ClusterState { - // Name of the cluster to be resolved. - protected final String name; - @Nullable - protected final ServerInfo lrsServerInfo; - @Nullable - protected final Long maxConcurrentRequests; - @Nullable - protected final UpstreamTlsContext tlsContext; - protected final Map filterMetadata; - @Nullable - protected final OutlierDetection outlierDetection; - // Resolution status, may contain most recent error encountered. - protected Status status = Status.OK; - // True if has received resolution result. - protected boolean resolved; - // Most recently resolved addresses and config, or null if resource not exists. 
- @Nullable - protected ClusterResolutionResult result; - - protected boolean shutdown; - - private ClusterState(String name, @Nullable ServerInfo lrsServerInfo, - @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext, - Map filterMetadata, @Nullable OutlierDetection outlierDetection) { - this.name = name; - this.lrsServerInfo = lrsServerInfo; - this.maxConcurrentRequests = maxConcurrentRequests; - this.tlsContext = tlsContext; - this.filterMetadata = ImmutableMap.copyOf(filterMetadata); - this.outlierDetection = outlierDetection; - } - - abstract void start(); - - void shutdown() { - shutdown = true; - } - } - - private final class EdsClusterState extends ClusterState implements ResourceWatcher { - @Nullable - private final String edsServiceName; - private Map localityPriorityNames = Collections.emptyMap(); - int priorityNameGenId = 1; - - private EdsClusterState(String name, @Nullable String edsServiceName, - @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, Map filterMetadata, - @Nullable OutlierDetection outlierDetection) { - super(name, lrsServerInfo, maxConcurrentRequests, tlsContext, filterMetadata, - outlierDetection); - this.edsServiceName = edsServiceName; - } - - @Override - void start() { - String resourceName = edsServiceName != null ? edsServiceName : name; - logger.log(XdsLogLevel.INFO, "Start watching EDS resource {0}", resourceName); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), - resourceName, this, syncContext); - } - - @Override - protected void shutdown() { - super.shutdown(); - String resourceName = edsServiceName != null ? 
edsServiceName : name; - logger.log(XdsLogLevel.INFO, "Stop watching EDS resource {0}", resourceName); - xdsClient.cancelXdsResourceWatch(XdsEndpointResource.getInstance(), resourceName, this); - } - - @Override - public void onChanged(final EdsUpdate update) { - class EndpointsUpdated implements Runnable { - @Override - public void run() { - if (shutdown) { - return; - } - logger.log(XdsLogLevel.DEBUG, "Received endpoint update {0}", update); - if (logger.isLoggable(XdsLogLevel.INFO)) { - logger.log(XdsLogLevel.INFO, "Cluster {0}: {1} localities, {2} drop categories", - update.clusterName, update.localityLbEndpointsMap.size(), - update.dropPolicies.size()); - } - Map localityLbEndpoints = - update.localityLbEndpointsMap; - List dropOverloads = update.dropPolicies; - List addresses = new ArrayList<>(); - Map> prioritizedLocalityWeights = new HashMap<>(); - List sortedPriorityNames = generatePriorityNames(name, localityLbEndpoints); - for (Locality locality : localityLbEndpoints.keySet()) { - LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); - String priorityName = localityPriorityNames.get(locality); - boolean discard = true; - for (LbEndpoint endpoint : localityLbInfo.endpoints()) { - if (endpoint.isHealthy()) { - discard = false; - long weight = localityLbInfo.localityWeight(); - if (endpoint.loadBalancingWeight() != 0) { - weight *= endpoint.loadBalancingWeight(); - } - String localityName = localityName(locality); - Attributes attr = - endpoint.eag().getAttributes().toBuilder() - .set(InternalXdsAttributes.ATTR_LOCALITY, locality) - .set(InternalXdsAttributes.ATTR_LOCALITY_NAME, localityName) - .set(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT, - localityLbInfo.localityWeight()) - .set(InternalXdsAttributes.ATTR_SERVER_WEIGHT, weight) - .build(); - EquivalentAddressGroup eag = new EquivalentAddressGroup( - endpoint.eag().getAddresses(), attr); - eag = AddressFilter.setPathFilter(eag, Arrays.asList(priorityName, localityName)); - 
addresses.add(eag); - } - } - if (discard) { - logger.log(XdsLogLevel.INFO, - "Discard locality {0} with 0 healthy endpoints", locality); - continue; - } - if (!prioritizedLocalityWeights.containsKey(priorityName)) { - prioritizedLocalityWeights.put(priorityName, new HashMap()); - } - prioritizedLocalityWeights.get(priorityName).put( - locality, localityLbInfo.localityWeight()); - } - if (prioritizedLocalityWeights.isEmpty()) { - // Will still update the result, as if the cluster resource is revoked. - logger.log(XdsLogLevel.INFO, - "Cluster {0} has no usable priority/locality/endpoint", update.clusterName); - } - sortedPriorityNames.retainAll(prioritizedLocalityWeights.keySet()); - Map priorityChildConfigs = - generateEdsBasedPriorityChildConfigs( - name, edsServiceName, lrsServerInfo, maxConcurrentRequests, tlsContext, - filterMetadata, outlierDetection, endpointLbConfig, lbRegistry, - prioritizedLocalityWeights, dropOverloads); - status = Status.OK; - resolved = true; - result = new ClusterResolutionResult(addresses, priorityChildConfigs, - sortedPriorityNames); - handleEndpointResourceUpdate(); - } - } - - new EndpointsUpdated().run(); - } - - private List generatePriorityNames(String name, - Map localityLbEndpoints) { - TreeMap> todo = new TreeMap<>(); - for (Locality locality : localityLbEndpoints.keySet()) { - int priority = localityLbEndpoints.get(locality).priority(); - if (!todo.containsKey(priority)) { - todo.put(priority, new ArrayList<>()); - } - todo.get(priority).add(locality); - } - Map newNames = new HashMap<>(); - Set usedNames = new HashSet<>(); - List ret = new ArrayList<>(); - for (Integer priority: todo.keySet()) { - String foundName = ""; - for (Locality locality : todo.get(priority)) { - if (localityPriorityNames.containsKey(locality) - && usedNames.add(localityPriorityNames.get(locality))) { - foundName = localityPriorityNames.get(locality); - break; - } - } - if ("".equals(foundName)) { - foundName = String.format(Locale.US, "%s[child%d]", 
name, priorityNameGenId++); - } - for (Locality locality : todo.get(priority)) { - newNames.put(locality, foundName); - } - ret.add(foundName); - } - localityPriorityNames = newNames; - return ret; - } - - @Override - public void onResourceDoesNotExist(final String resourceName) { - if (shutdown) { - return; - } - logger.log(XdsLogLevel.INFO, "Resource {0} unavailable", resourceName); - status = Status.OK; - resolved = true; - result = null; // resource revoked - handleEndpointResourceUpdate(); - } - - @Override - public void onError(final Status error) { - if (shutdown) { - return; - } - String resourceName = edsServiceName != null ? edsServiceName : name; - status = Status.UNAVAILABLE - .withDescription(String.format("Unable to load EDS %s. xDS server returned: %s: %s", - resourceName, error.getCode(), error.getDescription())) - .withCause(error.getCause()); - logger.log(XdsLogLevel.WARNING, "Received EDS error: {0}", error); - handleEndpointResolutionError(); - } - } - - private final class LogicalDnsClusterState extends ClusterState { - private final String dnsHostName; - private final NameResolver.Factory nameResolverFactory; - private final NameResolver.Args nameResolverArgs; - private NameResolver resolver; - @Nullable - private BackoffPolicy backoffPolicy; - @Nullable - private ScheduledHandle scheduledRefresh; - - private LogicalDnsClusterState(String name, String dnsHostName, - @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, Map filterMetadata) { - super(name, lrsServerInfo, maxConcurrentRequests, tlsContext, filterMetadata, null); - this.dnsHostName = checkNotNull(dnsHostName, "dnsHostName"); - nameResolverFactory = - checkNotNull(helper.getNameResolverRegistry().asFactory(), "nameResolverFactory"); - nameResolverArgs = checkNotNull(helper.getNameResolverArgs(), "nameResolverArgs"); - } - - @Override - void start() { - URI uri; - try { - uri = new URI("dns", "", "/" + dnsHostName, 
null); - } catch (URISyntaxException e) { - status = Status.INTERNAL.withDescription( - "Bug, invalid URI creation: " + dnsHostName).withCause(e); - handleEndpointResolutionError(); - return; - } - resolver = nameResolverFactory.newNameResolver(uri, nameResolverArgs); - if (resolver == null) { - status = Status.INTERNAL.withDescription("Xds cluster resolver lb for logical DNS " - + "cluster [" + name + "] cannot find DNS resolver with uri:" + uri); - handleEndpointResolutionError(); - return; - } - resolver.start(new NameResolverListener()); - } - - void refresh() { - if (resolver == null) { - return; - } - cancelBackoff(); - resolver.refresh(); - } - - @Override - void shutdown() { - super.shutdown(); - if (resolver != null) { - resolver.shutdown(); - } - cancelBackoff(); - } - - private void cancelBackoff() { - if (scheduledRefresh != null) { - scheduledRefresh.cancel(); - scheduledRefresh = null; - backoffPolicy = null; - } - } - - private class DelayedNameResolverRefresh implements Runnable { - @Override - public void run() { - scheduledRefresh = null; - if (!shutdown) { - resolver.refresh(); - } - } - } - - private class NameResolverListener extends NameResolver.Listener2 { - @Override - public void onResult(final ResolutionResult resolutionResult) { - class NameResolved implements Runnable { - @Override - public void run() { - if (shutdown) { - return; - } - backoffPolicy = null; // reset backoff sequence if succeeded - // Arbitrary priority notation for all DNS-resolved endpoints. - String priorityName = priorityName(name, 0); // value doesn't matter - List addresses = new ArrayList<>(); - for (EquivalentAddressGroup eag : resolutionResult.getAddresses()) { - // No weight attribute is attached, all endpoint-level LB policy should be able - // to handle such it. 
- String localityName = localityName(LOGICAL_DNS_CLUSTER_LOCALITY); - Attributes attr = eag.getAttributes().toBuilder() - .set(InternalXdsAttributes.ATTR_LOCALITY, LOGICAL_DNS_CLUSTER_LOCALITY) - .set(InternalXdsAttributes.ATTR_LOCALITY_NAME, localityName) - .build(); - eag = new EquivalentAddressGroup(eag.getAddresses(), attr); - eag = AddressFilter.setPathFilter(eag, Arrays.asList(priorityName, localityName)); - addresses.add(eag); - } - PriorityChildConfig priorityChildConfig = generateDnsBasedPriorityChildConfig( - name, lrsServerInfo, maxConcurrentRequests, tlsContext, filterMetadata, - lbRegistry, Collections.emptyList()); - status = Status.OK; - resolved = true; - result = new ClusterResolutionResult(addresses, priorityName, priorityChildConfig); - handleEndpointResourceUpdate(); - } - } - - syncContext.execute(new NameResolved()); - } - - @Override - public void onError(final Status error) { - syncContext.execute(new Runnable() { - @Override - public void run() { - if (shutdown) { - return; - } - status = error; - // NameResolver.Listener API cannot distinguish between address-not-found and - // transient errors. If the error occurs in the first resolution, treat it as - // address not found. Otherwise, either there is previously resolved addresses - // previously encountered error, propagate the error to downstream/upstream and - // let downstream/upstream handle it. 
- if (!resolved) { - resolved = true; - handleEndpointResourceUpdate(); - } else { - handleEndpointResolutionError(); - } - if (scheduledRefresh != null && scheduledRefresh.isPending()) { - return; - } - if (backoffPolicy == null) { - backoffPolicy = backoffPolicyProvider.get(); - } - long delayNanos = backoffPolicy.nextBackoffNanos(); - logger.log(XdsLogLevel.DEBUG, - "Logical DNS resolver for cluster {0} encountered name resolution " - + "error: {1}, scheduling DNS resolution backoff for {2} ns", - name, error, delayNanos); - scheduledRefresh = - syncContext.schedule( - new DelayedNameResolverRefresh(), delayNanos, TimeUnit.NANOSECONDS, - timeService); - } - }); - } - } - } - } - - private static class ClusterResolutionResult { - // Endpoint addresses. - private final List addresses; - // Config (include load balancing policy/config) for each priority in the cluster. - private final Map priorityChildConfigs; - // List of priority names ordered in descending priorities. - private final List priorities; - - ClusterResolutionResult(List addresses, String priority, - PriorityChildConfig config) { - this(addresses, Collections.singletonMap(priority, config), - Collections.singletonList(priority)); - } - - ClusterResolutionResult(List addresses, - Map configs, List priorities) { - this.addresses = addresses; - this.priorityChildConfigs = configs; - this.priorities = priorities; - } - } - - /** - * Generates the config to be used in the priority LB policy for the single priority of - * logical DNS cluster. - * - *

priority LB -> cluster_impl LB (single hardcoded priority) -> pick_first - */ - private static PriorityChildConfig generateDnsBasedPriorityChildConfig( - String cluster, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, Map filterMetadata, - LoadBalancerRegistry lbRegistry, List dropOverloads) { - // Override endpoint-level LB policy with pick_first for logical DNS cluster. - Object endpointLbConfig = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( - lbRegistry.getProvider("pick_first"), null); - ClusterImplConfig clusterImplConfig = - new ClusterImplConfig(cluster, null, lrsServerInfo, maxConcurrentRequests, - dropOverloads, endpointLbConfig, tlsContext, filterMetadata); - LoadBalancerProvider clusterImplLbProvider = - lbRegistry.getProvider(XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME); - Object clusterImplPolicy = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( - clusterImplLbProvider, clusterImplConfig); - return new PriorityChildConfig(clusterImplPolicy, false /* ignoreReresolution*/); - } - - /** - * Generates configs to be used in the priority LB policy for priorities in an EDS cluster. - * - *

priority LB -> cluster_impl LB (one per priority) -> (weighted_target LB - * -> round_robin / least_request_experimental (one per locality)) / ring_hash_experimental - */ - private static Map generateEdsBasedPriorityChildConfigs( - String cluster, @Nullable String edsServiceName, @Nullable ServerInfo lrsServerInfo, - @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext, - Map filterMetadata, - @Nullable OutlierDetection outlierDetection, Object endpointLbConfig, - LoadBalancerRegistry lbRegistry, Map> prioritizedLocalityWeights, List dropOverloads) { - Map configs = new HashMap<>(); - for (String priority : prioritizedLocalityWeights.keySet()) { - ClusterImplConfig clusterImplConfig = - new ClusterImplConfig(cluster, edsServiceName, lrsServerInfo, maxConcurrentRequests, - dropOverloads, endpointLbConfig, tlsContext, filterMetadata); - LoadBalancerProvider clusterImplLbProvider = - lbRegistry.getProvider(XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME); - Object priorityChildPolicy = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( - clusterImplLbProvider, clusterImplConfig); - - // If outlier detection has been configured we wrap the child policy in the outlier detection - // load balancer. - if (outlierDetection != null) { - LoadBalancerProvider outlierDetectionProvider = lbRegistry.getProvider( - "outlier_detection_experimental"); - priorityChildPolicy = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( - outlierDetectionProvider, - buildOutlierDetectionLbConfig(outlierDetection, priorityChildPolicy)); - } - - PriorityChildConfig priorityChildConfig = - new PriorityChildConfig(priorityChildPolicy, true /* ignoreReresolution */); - configs.put(priority, priorityChildConfig); - } - return configs; - } - - /** - * Converts {@link OutlierDetection} that represents the xDS configuration to {@link - * OutlierDetectionLoadBalancerConfig} that the {@link io.grpc.util.OutlierDetectionLoadBalancer} - * understands. 
- */ - private static OutlierDetectionLoadBalancerConfig buildOutlierDetectionLbConfig( - OutlierDetection outlierDetection, Object childConfig) { - OutlierDetectionLoadBalancerConfig.Builder configBuilder - = new OutlierDetectionLoadBalancerConfig.Builder(); - - configBuilder.setChildConfig(childConfig); - - if (outlierDetection.intervalNanos() != null) { - configBuilder.setIntervalNanos(outlierDetection.intervalNanos()); - } - if (outlierDetection.baseEjectionTimeNanos() != null) { - configBuilder.setBaseEjectionTimeNanos(outlierDetection.baseEjectionTimeNanos()); - } - if (outlierDetection.maxEjectionTimeNanos() != null) { - configBuilder.setMaxEjectionTimeNanos(outlierDetection.maxEjectionTimeNanos()); - } - if (outlierDetection.maxEjectionPercent() != null) { - configBuilder.setMaxEjectionPercent(outlierDetection.maxEjectionPercent()); - } - - SuccessRateEjection successRate = outlierDetection.successRateEjection(); - if (successRate != null) { - OutlierDetectionLoadBalancerConfig.SuccessRateEjection.Builder - successRateConfigBuilder = new OutlierDetectionLoadBalancerConfig - .SuccessRateEjection.Builder(); - - if (successRate.stdevFactor() != null) { - successRateConfigBuilder.setStdevFactor(successRate.stdevFactor()); - } - if (successRate.enforcementPercentage() != null) { - successRateConfigBuilder.setEnforcementPercentage(successRate.enforcementPercentage()); - } - if (successRate.minimumHosts() != null) { - successRateConfigBuilder.setMinimumHosts(successRate.minimumHosts()); - } - if (successRate.requestVolume() != null) { - successRateConfigBuilder.setRequestVolume(successRate.requestVolume()); - } - - configBuilder.setSuccessRateEjection(successRateConfigBuilder.build()); - } - - FailurePercentageEjection failurePercentage = outlierDetection.failurePercentageEjection(); - if (failurePercentage != null) { - OutlierDetectionLoadBalancerConfig.FailurePercentageEjection.Builder - failurePercentageConfigBuilder = new OutlierDetectionLoadBalancerConfig - 
.FailurePercentageEjection.Builder(); - - if (failurePercentage.threshold() != null) { - failurePercentageConfigBuilder.setThreshold(failurePercentage.threshold()); - } - if (failurePercentage.enforcementPercentage() != null) { - failurePercentageConfigBuilder.setEnforcementPercentage( - failurePercentage.enforcementPercentage()); - } - if (failurePercentage.minimumHosts() != null) { - failurePercentageConfigBuilder.setMinimumHosts(failurePercentage.minimumHosts()); - } - if (failurePercentage.requestVolume() != null) { - failurePercentageConfigBuilder.setRequestVolume(failurePercentage.requestVolume()); - } - - configBuilder.setFailurePercentageEjection(failurePercentageConfigBuilder.build()); - } - - return configBuilder.build(); - } - - /** - * Generates a string that represents the priority in the LB policy config. The string is unique - * across priorities in all clusters and priorityName(c, p1) < priorityName(c, p2) iff p1 < p2. - * The ordering is undefined for priorities in different clusters. - */ - private static String priorityName(String cluster, int priority) { - return cluster + "[child" + priority + "]"; - } - - /** - * Generates a string that represents the locality in the LB policy config. The string is unique - * across all localities in all clusters. - */ - private static String localityName(Locality locality) { - return "{region=\"" + locality.region() - + "\", zone=\"" + locality.zone() - + "\", sub_zone=\"" + locality.subZone() - + "\"}"; - } -} diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java deleted file mode 100644 index 2301cb670e0..00000000000 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.base.Preconditions.checkNotNull; - -import com.google.common.base.MoreObjects; -import com.google.common.collect.ImmutableMap; -import com.google.protobuf.Struct; -import io.grpc.Internal; -import io.grpc.LoadBalancer; -import io.grpc.LoadBalancer.Helper; -import io.grpc.LoadBalancerProvider; -import io.grpc.NameResolver.ConfigOrError; -import io.grpc.Status; -import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; -import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.client.Bootstrapper.ServerInfo; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import javax.annotation.Nullable; - -/** - * The provider for the cluster_resolver load balancing policy. This class should not be directly - * referenced in code. The policy should be accessed through - * {@link io.grpc.LoadBalancerRegistry#getProvider} with the name "cluster_resolver_experimental". 
- */ -@Internal -public final class ClusterResolverLoadBalancerProvider extends LoadBalancerProvider { - - @Override - public boolean isAvailable() { - return true; - } - - @Override - public int getPriority() { - return 5; - } - - @Override - public String getPolicyName() { - return XdsLbPolicies.CLUSTER_RESOLVER_POLICY_NAME; - } - - @Override - public ConfigOrError parseLoadBalancingPolicyConfig(Map rawLoadBalancingPolicyConfig) { - return ConfigOrError.fromError( - Status.INTERNAL.withDescription(getPolicyName() + " cannot be used from service config")); - } - - @Override - public LoadBalancer newLoadBalancer(Helper helper) { - return new ClusterResolverLoadBalancer(helper); - } - - static final class ClusterResolverConfig { - // Ordered list of clusters to be resolved. - final List discoveryMechanisms; - // GracefulSwitch configuration - final Object lbConfig; - - ClusterResolverConfig(List discoveryMechanisms, Object lbConfig) { - this.discoveryMechanisms = checkNotNull(discoveryMechanisms, "discoveryMechanisms"); - this.lbConfig = checkNotNull(lbConfig, "lbConfig"); - } - - @Override - public int hashCode() { - return Objects.hash(discoveryMechanisms, lbConfig); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ClusterResolverConfig that = (ClusterResolverConfig) o; - return discoveryMechanisms.equals(that.discoveryMechanisms) - && lbConfig.equals(that.lbConfig); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("discoveryMechanisms", discoveryMechanisms) - .add("lbConfig", lbConfig) - .toString(); - } - - // Describes the mechanism for a specific cluster. - static final class DiscoveryMechanism { - // Name of the cluster to resolve. - final String cluster; - // Type of the cluster. - final Type type; - // Load reporting server info. Null if not enabled. 
- @Nullable - final ServerInfo lrsServerInfo; - // Cluster-level max concurrent request threshold. Null if not specified. - @Nullable - final Long maxConcurrentRequests; - // TLS context for connections to endpoints in the cluster. - @Nullable - final UpstreamTlsContext tlsContext; - // Resource name for resolving endpoints via EDS. Only valid for EDS clusters. - @Nullable - final String edsServiceName; - // Hostname for resolving endpoints via DNS. Only valid for LOGICAL_DNS clusters. - @Nullable - final String dnsHostName; - @Nullable - final OutlierDetection outlierDetection; - final Map filterMetadata; - - enum Type { - EDS, - LOGICAL_DNS, - } - - private DiscoveryMechanism(String cluster, Type type, @Nullable String edsServiceName, - @Nullable String dnsHostName, @Nullable ServerInfo lrsServerInfo, - @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext, - Map filterMetadata, @Nullable OutlierDetection outlierDetection) { - this.cluster = checkNotNull(cluster, "cluster"); - this.type = checkNotNull(type, "type"); - this.edsServiceName = edsServiceName; - this.dnsHostName = dnsHostName; - this.lrsServerInfo = lrsServerInfo; - this.maxConcurrentRequests = maxConcurrentRequests; - this.tlsContext = tlsContext; - this.filterMetadata = ImmutableMap.copyOf(checkNotNull(filterMetadata, "filterMetadata")); - this.outlierDetection = outlierDetection; - } - - static DiscoveryMechanism forEds(String cluster, @Nullable String edsServiceName, - @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, Map filterMetadata, - OutlierDetection outlierDetection) { - return new DiscoveryMechanism(cluster, Type.EDS, edsServiceName, null, lrsServerInfo, - maxConcurrentRequests, tlsContext, filterMetadata, outlierDetection); - } - - static DiscoveryMechanism forLogicalDns(String cluster, String dnsHostName, - @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable 
UpstreamTlsContext tlsContext, Map filterMetadata) { - return new DiscoveryMechanism(cluster, Type.LOGICAL_DNS, null, dnsHostName, - lrsServerInfo, maxConcurrentRequests, tlsContext, filterMetadata, null); - } - - @Override - public int hashCode() { - return Objects.hash(cluster, type, lrsServerInfo, maxConcurrentRequests, tlsContext, - edsServiceName, dnsHostName, filterMetadata, outlierDetection); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - DiscoveryMechanism that = (DiscoveryMechanism) o; - return cluster.equals(that.cluster) - && type == that.type - && Objects.equals(edsServiceName, that.edsServiceName) - && Objects.equals(dnsHostName, that.dnsHostName) - && Objects.equals(lrsServerInfo, that.lrsServerInfo) - && Objects.equals(maxConcurrentRequests, that.maxConcurrentRequests) - && Objects.equals(tlsContext, that.tlsContext) - && Objects.equals(filterMetadata, that.filterMetadata) - && Objects.equals(outlierDetection, that.outlierDetection); - } - - @Override - public String toString() { - MoreObjects.ToStringHelper toStringHelper = - MoreObjects.toStringHelper(this) - .add("cluster", cluster) - .add("type", type) - .add("edsServiceName", edsServiceName) - .add("dnsHostName", dnsHostName) - .add("lrsServerInfo", lrsServerInfo) - // Exclude tlsContext as its string representation is cumbersome. - .add("maxConcurrentRequests", maxConcurrentRequests) - .add("filterMetadata", filterMetadata) - // Exclude outlierDetection as its string representation is long. 
- ; - return toStringHelper.toString(); - } - } - } -} diff --git a/xds/src/main/java/io/grpc/xds/CsdsService.java b/xds/src/main/java/io/grpc/xds/CsdsService.java index a296beb45d0..8c2fe333c15 100644 --- a/xds/src/main/java/io/grpc/xds/CsdsService.java +++ b/xds/src/main/java/io/grpc/xds/CsdsService.java @@ -249,6 +249,8 @@ static ClientResourceStatus metadataStatusToClientStatus(ResourceMetadataStatus return ClientResourceStatus.ACKED; case NACKED: return ClientResourceStatus.NACKED; + case TIMEOUT: + return ClientResourceStatus.TIMEOUT; default: throw new AssertionError("Unexpected ResourceMetadataStatus: " + status); } diff --git a/xds/src/main/java/io/grpc/xds/Endpoints.java b/xds/src/main/java/io/grpc/xds/Endpoints.java index 8b1715731df..558e3932ddc 100644 --- a/xds/src/main/java/io/grpc/xds/Endpoints.java +++ b/xds/src/main/java/io/grpc/xds/Endpoints.java @@ -21,6 +21,8 @@ import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.net.InetAddresses; import io.grpc.EquivalentAddressGroup; import java.net.InetSocketAddress; import java.util.List; @@ -41,11 +43,13 @@ abstract static class LocalityLbEndpoints { // Locality's priority level. abstract int priority(); + abstract ImmutableMap localityMetadata(); + static LocalityLbEndpoints create(List endpoints, int localityWeight, - int priority) { + int priority, ImmutableMap localityMetadata) { checkArgument(localityWeight > 0, "localityWeight must be greater than 0"); return new AutoValue_Endpoints_LocalityLbEndpoints( - ImmutableList.copyOf(endpoints), localityWeight, priority); + ImmutableList.copyOf(endpoints), localityWeight, priority, localityMetadata); } } @@ -55,23 +59,32 @@ abstract static class LbEndpoint { // The endpoint address to be connected to. abstract EquivalentAddressGroup eag(); - // Endpoint's weight for load balancing. 
If unspecified, value of 0 is returned. + // Endpoint's weight for load balancing. Guaranteed not to be 0. abstract int loadBalancingWeight(); // Whether the endpoint is healthy. abstract boolean isHealthy(); + abstract String hostname(); + + abstract ImmutableMap endpointMetadata(); + static LbEndpoint create(EquivalentAddressGroup eag, int loadBalancingWeight, - boolean isHealthy) { - return new AutoValue_Endpoints_LbEndpoint(eag, loadBalancingWeight, isHealthy); + boolean isHealthy, String hostname, ImmutableMap endpointMetadata) { + if (loadBalancingWeight == 0) { + loadBalancingWeight = 1; + } + return new AutoValue_Endpoints_LbEndpoint( + eag, loadBalancingWeight, isHealthy, hostname, endpointMetadata); } // Only for testing. @VisibleForTesting - static LbEndpoint create( - String address, int port, int loadBalancingWeight, boolean isHealthy) { - return LbEndpoint.create(new EquivalentAddressGroup(new InetSocketAddress(address, port)), - loadBalancingWeight, isHealthy); + static LbEndpoint create(String address, int port, int loadBalancingWeight, boolean isHealthy, + String hostname, ImmutableMap endpointMetadata) { + return LbEndpoint.create( + new EquivalentAddressGroup(new InetSocketAddress(InetAddresses.forString(address), port)), + loadBalancingWeight, isHealthy, hostname, endpointMetadata); } } diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index 978e6663cbe..01ef3d97b57 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -16,16 +16,18 @@ package io.grpc.xds; +import static com.google.common.base.Preconditions.checkNotNull; + import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.protobuf.util.Durations; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import 
io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.grpc.Internal; import io.grpc.xds.client.EnvoyProtoData; import io.grpc.xds.internal.security.SslContextProviderSupplier; import java.net.InetAddress; -import java.net.UnknownHostException; import java.util.Objects; import javax.annotation.Nullable; @@ -41,13 +43,13 @@ private EnvoyServerProtoData() { } public abstract static class BaseTlsContext { - @Nullable protected final CommonTlsContext commonTlsContext; + protected final CommonTlsContext commonTlsContext; - protected BaseTlsContext(@Nullable CommonTlsContext commonTlsContext) { - this.commonTlsContext = commonTlsContext; + protected BaseTlsContext(CommonTlsContext commonTlsContext) { + this.commonTlsContext = checkNotNull(commonTlsContext, "commonTlsContext cannot be null."); } - @Nullable public CommonTlsContext getCommonTlsContext() { + public CommonTlsContext getCommonTlsContext() { return commonTlsContext; } @@ -71,20 +73,54 @@ public int hashCode() { public static final class UpstreamTlsContext extends BaseTlsContext { + private final String sni; + private final boolean autoHostSni; + private final boolean autoSniSanValidation; + @VisibleForTesting public UpstreamTlsContext(CommonTlsContext commonTlsContext) { super(commonTlsContext); + this.sni = null; + this.autoHostSni = false; + this.autoSniSanValidation = false; + } + + @VisibleForTesting + public UpstreamTlsContext( + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + upstreamTlsContext) { + super(upstreamTlsContext.getCommonTlsContext()); + this.sni = upstreamTlsContext.getSni(); + this.autoHostSni = upstreamTlsContext.getAutoHostSni(); + this.autoSniSanValidation = upstreamTlsContext.getAutoSniSanValidation(); } public static UpstreamTlsContext fromEnvoyProtoUpstreamTlsContext( io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext) { - return new 
UpstreamTlsContext(upstreamTlsContext.getCommonTlsContext()); + return new UpstreamTlsContext(upstreamTlsContext); + } + + public String getSni() { + return sni; + } + + public boolean getAutoHostSni() { + return autoHostSni; + } + + public boolean getAutoSniSanValidation() { + return autoSniSanValidation; } @Override public String toString() { - return "UpstreamTlsContext{" + "commonTlsContext=" + commonTlsContext + '}'; + return "UpstreamTlsContext{" + + "commonTlsContext=" + commonTlsContext + + "\nsni=" + sni + + "\nauto_host_sni=" + autoHostSni + + "\nauto_sni_san_validation=" + autoSniSanValidation + + "}"; } } @@ -148,9 +184,9 @@ abstract static class CidrRange { abstract int prefixLen(); - static CidrRange create(String addressPrefix, int prefixLen) throws UnknownHostException { + static CidrRange create(InetAddress addressPrefix, int prefixLen) { return new AutoValue_EnvoyServerProtoData_CidrRange( - InetAddress.getByName(addressPrefix), prefixLen); + addressPrefix, prefixLen); } } @@ -205,7 +241,7 @@ public static FilterChainMatch create(int destinationPort, @AutoValue abstract static class FilterChain { - // possibly empty + // Must be unique per server instance (except the default chain). abstract String name(); // TODO(sanjaypujare): flatten structure by moving FilterChainMatch class members here. @@ -247,13 +283,17 @@ abstract static class Listener { @Nullable abstract FilterChain defaultFilterChain(); + @Nullable + abstract Protocol protocol(); + static Listener create( String name, @Nullable String address, ImmutableList filterChains, - @Nullable FilterChain defaultFilterChain) { + @Nullable FilterChain defaultFilterChain, + @Nullable Protocol protocol) { return new AutoValue_EnvoyServerProtoData_Listener(name, address, filterChains, - defaultFilterChain); + defaultFilterChain, protocol); } } @@ -322,7 +362,7 @@ static OutlierDetection fromEnvoyOutlierDetection( Integer minimumHosts = envoyOutlierDetection.hasSuccessRateMinimumHosts() ? 
envoyOutlierDetection.getSuccessRateMinimumHosts().getValue() : null; Integer requestVolume = envoyOutlierDetection.hasSuccessRateRequestVolume() - ? envoyOutlierDetection.getSuccessRateMinimumHosts().getValue() : null; + ? envoyOutlierDetection.getSuccessRateRequestVolume().getValue() : null; successRateEjection = SuccessRateEjection.create(stdevFactor, enforcementPercentage, minimumHosts, requestVolume); diff --git a/xds/src/main/java/io/grpc/xds/ExtAuthzConfigParser.java b/xds/src/main/java/io/grpc/xds/ExtAuthzConfigParser.java new file mode 100644 index 00000000000..853e8a5c03a --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/ExtAuthzConfigParser.java @@ -0,0 +1,103 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds; + +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.grpc.internal.GrpcUtil; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.internal.MatcherParser; +import io.grpc.xds.internal.extauthz.ExtAuthzConfig; +import io.grpc.xds.internal.extauthz.ExtAuthzParseException; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig; +import io.grpc.xds.internal.grpcservice.GrpcServiceParseException; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesParseException; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesParser; + + +/** + * Parser for {@link io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz}. + */ +final class ExtAuthzConfigParser { + + private ExtAuthzConfigParser() {} + + /** + * Parses the {@link io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz} proto to + * create an {@link ExtAuthzConfig} instance. + * + * @param extAuthzProto The ext_authz proto to parse. + * @return An {@link ExtAuthzConfig} instance. + * @throws ExtAuthzParseException if the proto is invalid or contains unsupported features. 
+ */ + public static ExtAuthzConfig parse( + ExtAuthz extAuthzProto, BootstrapInfo bootstrapInfo, ServerInfo serverInfo) + throws ExtAuthzParseException { + if (!extAuthzProto.hasGrpcService()) { + throw new ExtAuthzParseException( + "unsupported ExtAuthz service type: only grpc_service is supported"); + } + GrpcServiceConfig grpcServiceConfig; + try { + grpcServiceConfig = + GrpcServiceConfigParser.parse(extAuthzProto.getGrpcService(), bootstrapInfo, serverInfo); + } catch (GrpcServiceParseException e) { + throw new ExtAuthzParseException("Failed to parse GrpcService config: " + e.getMessage(), e); + } + ExtAuthzConfig.Builder builder = ExtAuthzConfig.builder().grpcService(grpcServiceConfig) + .failureModeAllow(extAuthzProto.getFailureModeAllow()) + .failureModeAllowHeaderAdd(extAuthzProto.getFailureModeAllowHeaderAdd()) + .includePeerCertificate(extAuthzProto.getIncludePeerCertificate()) + .denyAtDisable(extAuthzProto.getDenyAtDisable().getDefaultValue().getValue()); + + if (extAuthzProto.hasFilterEnabled()) { + try { + builder.filterEnabled( + MatcherParser.parseFractionMatcher(extAuthzProto.getFilterEnabled().getDefaultValue())); + } catch (IllegalArgumentException e) { + throw new ExtAuthzParseException(e.getMessage()); + } + } + + if (extAuthzProto.hasStatusOnError()) { + builder.statusOnError( + GrpcUtil.httpStatusToGrpcStatus(extAuthzProto.getStatusOnError().getCodeValue())); + } + + if (extAuthzProto.hasAllowedHeaders()) { + builder.allowedHeaders(extAuthzProto.getAllowedHeaders().getPatternsList().stream() + .map(MatcherParser::parseStringMatcher).collect(ImmutableList.toImmutableList())); + } + + if (extAuthzProto.hasDisallowedHeaders()) { + builder.disallowedHeaders(extAuthzProto.getDisallowedHeaders().getPatternsList().stream() + .map(MatcherParser::parseStringMatcher).collect(ImmutableList.toImmutableList())); + } + + if (extAuthzProto.hasDecoderHeaderMutationRules()) { + try { + builder.decoderHeaderMutationRules( + 
HeaderMutationRulesParser.parse(extAuthzProto.getDecoderHeaderMutationRules())); + } catch (HeaderMutationRulesParseException e) { + throw new ExtAuthzParseException(e.getMessage(), e); + } + } + + return builder.build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/FaultFilter.java b/xds/src/main/java/io/grpc/xds/FaultFilter.java index d46b3d30f5a..0f3bb5b0557 100644 --- a/xds/src/main/java/io/grpc/xds/FaultFilter.java +++ b/xds/src/main/java/io/grpc/xds/FaultFilter.java @@ -37,7 +37,6 @@ import io.grpc.Deadline; import io.grpc.ForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; -import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -46,7 +45,6 @@ import io.grpc.internal.GrpcUtil; import io.grpc.xds.FaultConfig.FaultAbort; import io.grpc.xds.FaultConfig.FaultDelay; -import io.grpc.xds.Filter.ClientInterceptorBuilder; import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import java.util.Locale; import java.util.concurrent.Executor; @@ -57,10 +55,11 @@ import javax.annotation.Nullable; /** HttpFault filter implementation. 
*/ -final class FaultFilter implements Filter, ClientInterceptorBuilder { +final class FaultFilter implements Filter { - static final FaultFilter INSTANCE = + private static final FaultFilter INSTANCE = new FaultFilter(ThreadSafeRandomImpl.instance, new AtomicLong()); + @VisibleForTesting static final Metadata.Key HEADER_DELAY_KEY = Metadata.Key.of("x-envoy-fault-delay-request", Metadata.ASCII_STRING_MARSHALLER); @@ -88,196 +87,216 @@ final class FaultFilter implements Filter, ClientInterceptorBuilder { this.activeFaultCounter = activeFaultCounter; } - @Override - public String[] typeUrls() { - return new String[] { TYPE_URL }; - } - - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - HTTPFault httpFaultProto; - if (!(rawProtoMessage instanceof Any)) { - return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + static final class Provider implements Filter.Provider { + @Override + public String[] typeUrls() { + return new String[]{TYPE_URL}; } - Any anyMessage = (Any) rawProtoMessage; - try { - httpFaultProto = anyMessage.unpack(HTTPFault.class); - } catch (InvalidProtocolBufferException e) { - return ConfigOrError.fromError("Invalid proto: " + e); + + @Override + public boolean isClientFilter() { + return true; } - return parseHttpFault(httpFaultProto); - } - private static ConfigOrError parseHttpFault(HTTPFault httpFault) { - FaultDelay faultDelay = null; - FaultAbort faultAbort = null; - if (httpFault.hasDelay()) { - faultDelay = parseFaultDelay(httpFault.getDelay()); + @Override + public FaultFilter newInstance(String name) { + return INSTANCE; } - if (httpFault.hasAbort()) { - ConfigOrError faultAbortOrError = parseFaultAbort(httpFault.getAbort()); - if (faultAbortOrError.errorDetail != null) { - return ConfigOrError.fromError( - "HttpFault contains invalid FaultAbort: " + faultAbortOrError.errorDetail); + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + HTTPFault 
httpFaultProto; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); } - faultAbort = faultAbortOrError.config; - } - Integer maxActiveFaults = null; - if (httpFault.hasMaxActiveFaults()) { - maxActiveFaults = httpFault.getMaxActiveFaults().getValue(); - if (maxActiveFaults < 0) { - maxActiveFaults = Integer.MAX_VALUE; + Any anyMessage = (Any) rawProtoMessage; + try { + httpFaultProto = anyMessage.unpack(HTTPFault.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); } + return parseHttpFault(httpFaultProto); } - return ConfigOrError.fromConfig(FaultConfig.create(faultDelay, faultAbort, maxActiveFaults)); - } - private static FaultDelay parseFaultDelay( - io.envoyproxy.envoy.extensions.filters.common.fault.v3.FaultDelay faultDelay) { - FaultConfig.FractionalPercent percent = parsePercent(faultDelay.getPercentage()); - if (faultDelay.hasHeaderDelay()) { - return FaultDelay.forHeader(percent); + @Override + public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { + return parseFilterConfig(rawProtoMessage); } - return FaultDelay.forFixedDelay(Durations.toNanos(faultDelay.getFixedDelay()), percent); - } - @VisibleForTesting - static ConfigOrError parseFaultAbort( - io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort faultAbort) { - FaultConfig.FractionalPercent percent = parsePercent(faultAbort.getPercentage()); - switch (faultAbort.getErrorTypeCase()) { - case HEADER_ABORT: - return ConfigOrError.fromConfig(FaultAbort.forHeader(percent)); - case HTTP_STATUS: - return ConfigOrError.fromConfig(FaultAbort.forStatus( - GrpcUtil.httpStatusToGrpcStatus(faultAbort.getHttpStatus()), percent)); - case GRPC_STATUS: - return ConfigOrError.fromConfig(FaultAbort.forStatus( - Status.fromCodeValue(faultAbort.getGrpcStatus()), percent)); - case ERRORTYPE_NOT_SET: - default: - return ConfigOrError.fromError( - "Unknown 
error type case: " + faultAbort.getErrorTypeCase()); + private static ConfigOrError parseHttpFault(HTTPFault httpFault) { + FaultDelay faultDelay = null; + FaultAbort faultAbort = null; + if (httpFault.hasDelay()) { + faultDelay = parseFaultDelay(httpFault.getDelay()); + } + if (httpFault.hasAbort()) { + ConfigOrError faultAbortOrError = parseFaultAbort(httpFault.getAbort()); + if (faultAbortOrError.errorDetail != null) { + return ConfigOrError.fromError( + "HttpFault contains invalid FaultAbort: " + faultAbortOrError.errorDetail); + } + faultAbort = faultAbortOrError.config; + } + Integer maxActiveFaults = null; + if (httpFault.hasMaxActiveFaults()) { + maxActiveFaults = httpFault.getMaxActiveFaults().getValue(); + if (maxActiveFaults < 0) { + maxActiveFaults = Integer.MAX_VALUE; + } + } + return ConfigOrError.fromConfig(FaultConfig.create(faultDelay, faultAbort, maxActiveFaults)); } - } - private static FaultConfig.FractionalPercent parsePercent(FractionalPercent proto) { - switch (proto.getDenominator()) { - case HUNDRED: - return FaultConfig.FractionalPercent.perHundred(proto.getNumerator()); - case TEN_THOUSAND: - return FaultConfig.FractionalPercent.perTenThousand(proto.getNumerator()); - case MILLION: - return FaultConfig.FractionalPercent.perMillion(proto.getNumerator()); - case UNRECOGNIZED: - default: - throw new IllegalArgumentException("Unknown denominator type: " + proto.getDenominator()); + private static FaultDelay parseFaultDelay( + io.envoyproxy.envoy.extensions.filters.common.fault.v3.FaultDelay faultDelay) { + FaultConfig.FractionalPercent percent = parsePercent(faultDelay.getPercentage()); + if (faultDelay.hasHeaderDelay()) { + return FaultDelay.forHeader(percent); + } + return FaultDelay.forFixedDelay(Durations.toNanos(faultDelay.getFixedDelay()), percent); } - } - @Override - public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { - return parseFilterConfig(rawProtoMessage); + @VisibleForTesting + static ConfigOrError 
parseFaultAbort( + io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort faultAbort) { + FaultConfig.FractionalPercent percent = parsePercent(faultAbort.getPercentage()); + switch (faultAbort.getErrorTypeCase()) { + case HEADER_ABORT: + return ConfigOrError.fromConfig(FaultAbort.forHeader(percent)); + case HTTP_STATUS: + return ConfigOrError.fromConfig(FaultAbort.forStatus( + GrpcUtil.httpStatusToGrpcStatus(faultAbort.getHttpStatus()), percent)); + case GRPC_STATUS: + return ConfigOrError.fromConfig(FaultAbort.forStatus( + Status.fromCodeValue(faultAbort.getGrpcStatus()), percent)); + case ERRORTYPE_NOT_SET: + default: + return ConfigOrError.fromError( + "Unknown error type case: " + faultAbort.getErrorTypeCase()); + } + } + + private static FaultConfig.FractionalPercent parsePercent(FractionalPercent proto) { + switch (proto.getDenominator()) { + case HUNDRED: + return FaultConfig.FractionalPercent.perHundred(proto.getNumerator()); + case TEN_THOUSAND: + return FaultConfig.FractionalPercent.perTenThousand(proto.getNumerator()); + case MILLION: + return FaultConfig.FractionalPercent.perMillion(proto.getNumerator()); + case UNRECOGNIZED: + default: + throw new IllegalArgumentException("Unknown denominator type: " + proto.getDenominator()); + } + } } @Nullable @Override public ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, + FilterConfig config, @Nullable FilterConfig overrideConfig, final ScheduledExecutorService scheduler) { checkNotNull(config, "config"); if (overrideConfig != null) { config = overrideConfig; } FaultConfig faultConfig = (FaultConfig) config; - Long delayNanos = null; - Status abortStatus = null; - if (faultConfig.maxActiveFaults() == null - || activeFaultCounter.get() < faultConfig.maxActiveFaults()) { - Metadata headers = args.getHeaders(); - if (faultConfig.faultDelay() != null) { - delayNanos = determineFaultDelayNanos(faultConfig.faultDelay(), headers); 
- } - if (faultConfig.faultAbort() != null) { - abortStatus = determineFaultAbortStatus(faultConfig.faultAbort(), headers); - } - } - if (delayNanos == null && abortStatus == null) { - return null; - } - final Long finalDelayNanos = delayNanos; - final Status finalAbortStatus = getAbortStatusWithDescription(abortStatus); final class FaultInjectionInterceptor implements ClientInterceptor { @Override public ClientCall interceptCall( final MethodDescriptor method, final CallOptions callOptions, final Channel next) { - Executor callExecutor = callOptions.getExecutor(); - if (callExecutor == null) { // This should never happen in practice because - // ManagedChannelImpl.ConfigSelectingClientCall always provides CallOptions with - // a callExecutor. - // TODO(https://github.com/grpc/grpc-java/issues/7868) - callExecutor = MoreExecutors.directExecutor(); + boolean checkFault = false; + if (faultConfig.maxActiveFaults() == null + || activeFaultCounter.get() < faultConfig.maxActiveFaults()) { + checkFault = faultConfig.faultDelay() != null || faultConfig.faultAbort() != null; } - if (finalDelayNanos != null) { - Supplier> callSupplier; - if (finalAbortStatus != null) { - callSupplier = Suppliers.ofInstance( - new FailingClientCall(finalAbortStatus, callExecutor)); - } else { - callSupplier = new Supplier>() { - @Override - public ClientCall get() { - return next.newCall(method, callOptions); - } - }; + if (!checkFault) { + return next.newCall(method, callOptions); + } + final class DeadlineInsightForwardingCall extends ForwardingClientCall { + private ClientCall delegate; + + @Override + protected ClientCall delegate() { + return delegate; } - final DelayInjectedCall delayInjectedCall = new DelayInjectedCall<>( - finalDelayNanos, callExecutor, scheduler, callOptions.getDeadline(), callSupplier); - final class DeadlineInsightForwardingCall extends ForwardingClientCall { - @Override - protected ClientCall delegate() { - return delayInjectedCall; + @Override + public void 
start(Listener listener, Metadata headers) { + Executor callExecutor = callOptions.getExecutor(); + if (callExecutor == null) { // This should never happen in practice because + // ManagedChannelImpl.ConfigSelectingClientCall always provides CallOptions with + // a callExecutor. + // TODO(https://github.com/grpc/grpc-java/issues/7868) + callExecutor = MoreExecutors.directExecutor(); } - @Override - public void start(Listener listener, Metadata headers) { - Listener finalListener = - new SimpleForwardingClientCallListener(listener) { - @Override - public void onClose(Status status, Metadata trailers) { - if (status.getCode().equals(Code.DEADLINE_EXCEEDED)) { - // TODO(zdapeng:) check effective deadline locally, and - // do the following only if the local deadline is exceeded. - // (If the server sends DEADLINE_EXCEEDED for its own deadline, then the - // injected delay does not contribute to the error, because the request is - // only sent out after the delay. There could be a race between local and - // remote, but it is rather rare.) - String description = String.format( - Locale.US, - "Deadline exceeded after up to %d ns of fault-injected delay", - finalDelayNanos); - if (status.getDescription() != null) { - description = description + ": " + status.getDescription(); - } - status = Status.DEADLINE_EXCEEDED - .withDescription(description).withCause(status.getCause()); - // Replace trailers to prevent mixing sources of status and trailers. 
- trailers = new Metadata(); + Long delayNanos; + Status abortStatus = null; + if (faultConfig.faultDelay() != null) { + delayNanos = determineFaultDelayNanos(faultConfig.faultDelay(), headers); + } else { + delayNanos = null; + } + if (faultConfig.faultAbort() != null) { + abortStatus = getAbortStatusWithDescription( + determineFaultAbortStatus(faultConfig.faultAbort(), headers)); + } + + Supplier> callSupplier; + if (abortStatus != null) { + callSupplier = Suppliers.ofInstance( + new FailingClientCall(abortStatus, callExecutor)); + } else { + callSupplier = new Supplier>() { + @Override + public ClientCall get() { + return next.newCall(method, callOptions); + } + }; + } + if (delayNanos == null) { + delegate = callSupplier.get(); + delegate().start(listener, headers); + return; + } + + delegate = new DelayInjectedCall<>( + delayNanos, callExecutor, scheduler, callOptions.getDeadline(), callSupplier); + + Listener finalListener = + new SimpleForwardingClientCallListener(listener) { + @Override + public void onClose(Status status, Metadata trailers) { + if (status.getCode().equals(Code.DEADLINE_EXCEEDED)) { + // TODO(zdapeng:) check effective deadline locally, and + // do the following only if the local deadline is exceeded. + // (If the server sends DEADLINE_EXCEEDED for its own deadline, then the + // injected delay does not contribute to the error, because the request is + // only sent out after the delay. There could be a race between local and + // remote, but it is rather rare.) + String description = String.format( + Locale.US, + "Deadline exceeded after up to %d ns of fault-injected delay", + delayNanos); + if (status.getDescription() != null) { + description = description + ": " + status.getDescription(); } - delegate().onClose(status, trailers); + status = Status.DEADLINE_EXCEEDED + .withDescription(description).withCause(status.getCause()); + // Replace trailers to prevent mixing sources of status and trailers. 
+ trailers = new Metadata(); } - }; - delegate().start(finalListener, headers); - } + delegate().onClose(status, trailers); + } + }; + delegate().start(finalListener, headers); } - - return new DeadlineInsightForwardingCall(); - } else { - return new FailingClientCall<>(finalAbortStatus, callExecutor); } + + return new DeadlineInsightForwardingCall(); } } diff --git a/xds/src/main/java/io/grpc/xds/Filter.java b/xds/src/main/java/io/grpc/xds/Filter.java index 4b2767687f3..416d929becf 100644 --- a/xds/src/main/java/io/grpc/xds/Filter.java +++ b/xds/src/main/java/io/grpc/xds/Filter.java @@ -19,57 +19,112 @@ import com.google.common.base.MoreObjects; import com.google.protobuf.Message; import io.grpc.ClientInterceptor; -import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.ServerInterceptor; +import java.io.Closeable; import java.util.Objects; import java.util.concurrent.ScheduledExecutorService; import javax.annotation.Nullable; /** - * Defines the parsing functionality of an HTTP filter. A Filter may optionally implement either - * {@link ClientInterceptorBuilder} or {@link ServerInterceptorBuilder} or both, indicating it is - * capable of working on the client side or server side or both, respectively. + * Defines the parsing functionality of an HTTP filter. + * + *

A Filter may optionally implement either {@link Filter#buildClientInterceptor} or + * {@link Filter#buildServerInterceptor} or both, and return true from corresponding + * {@link Provider#isClientFilter()}, {@link Provider#isServerFilter()} to indicate that the filter + * is capable of working on the client side or server side or both, respectively. */ -interface Filter { +interface Filter extends Closeable { - /** - * The proto message types supported by this filter. A filter will be registered by each of its - * supported message types. - */ - String[] typeUrls(); + /** Represents an opaque data structure holding configuration for a filter. */ + interface FilterConfig { + String typeUrl(); + } /** - * Parses the top-level filter config from raw proto message. The message may be either a {@link - * com.google.protobuf.Any} or a {@link com.google.protobuf.Struct}. + * Common interface for filter providers. */ - ConfigOrError parseFilterConfig(Message rawProtoMessage); + interface Provider { + /** + * The proto message types supported by this filter. A filter will be registered by each of its + * supported message types. + */ + String[] typeUrls(); - /** - * Parses the per-filter override filter config from raw proto message. The message may be either - * a {@link com.google.protobuf.Any} or a {@link com.google.protobuf.Struct}. - */ - ConfigOrError parseFilterConfigOverride(Message rawProtoMessage); + /** + * Whether the filter can be installed on the client side. + * + *

Returns true if the filter implements {@link Filter#buildClientInterceptor}. + */ + default boolean isClientFilter() { + return false; + } - /** Represents an opaque data structure holding configuration for a filter. */ - interface FilterConfig { - String typeUrl(); + /** + * Whether the filter can be installed into xDS-enabled servers. + * + *

Returns true if the filter implements {@link Filter#buildServerInterceptor}. + */ + default boolean isServerFilter() { + return false; + } + + /** + * Creates a new instance of the filter. + * + *

Returns a filter instance registered with the same typeUrls as the provider, + * capable of working with the same FilterConfig type returned by provider's parse functions. + * + *

For xDS gRPC clients, new filter instances are created per combination of: + *

    + *
  1. XdsNameResolver instance,
  2. + *
  3. Filter name+typeUrl in HttpConnectionManager (HCM) http_filters.
  4. + *
+ * + *

For xDS-enabled gRPC servers, new filter instances are created per combination of: + *

    + *
  1. Server instance,
  2. + *
  3. FilterChain name,
  4. + *
  5. Filter name+typeUrl in FilterChain's HCM.http_filters.
  6. + *
+ */ + Filter newInstance(String name); + + /** + * Parses the top-level filter config from raw proto message. The message may be either a {@link + * com.google.protobuf.Any} or a {@link com.google.protobuf.Struct}. + */ + ConfigOrError parseFilterConfig(Message rawProtoMessage); + + /** + * Parses the per-filter override filter config from raw proto message. The message may be + * either a {@link com.google.protobuf.Any} or a {@link com.google.protobuf.Struct}. + */ + ConfigOrError parseFilterConfigOverride(Message rawProtoMessage); } /** Uses the FilterConfigs produced above to produce an HTTP filter interceptor for clients. */ - interface ClientInterceptorBuilder { - @Nullable - ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, - ScheduledExecutorService scheduler); + @Nullable + default ClientInterceptor buildClientInterceptor( + FilterConfig config, @Nullable FilterConfig overrideConfig, + ScheduledExecutorService scheduler) { + return null; } /** Uses the FilterConfigs produced above to produce an HTTP filter interceptor for the server. */ - interface ServerInterceptorBuilder { - @Nullable - ServerInterceptor buildServerInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig); + @Nullable + default ServerInterceptor buildServerInterceptor( + FilterConfig config, @Nullable FilterConfig overrideConfig) { + return null; } + /** + * Releases filter resources like shared resources and remote connections. + * + *

See {@link Provider#newInstance()} for details on filter instance creation. + */ + @Override + default void close() {} + /** Filter config with instance name. */ final class NamedFilterConfig { // filter instance name @@ -81,6 +136,10 @@ final class NamedFilterConfig { this.filterConfig = filterConfig; } + String filterStateKey() { + return name + "_" + filterConfig.typeUrl(); + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java index 37a0e6a8ae0..77a66495614 100644 --- a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java @@ -17,8 +17,8 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.InternalXdsAttributes.ATTR_DRAIN_GRACE_NANOS; -import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; +import static io.grpc.xds.XdsAttributes.ATTR_DRAIN_GRACE_NANOS; +import static io.grpc.xds.XdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; import static io.grpc.xds.internal.security.SecurityProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; @@ -151,6 +151,10 @@ static final class FilterChainSelector { this.defaultRoutingConfig = checkNotNull(defaultRoutingConfig, "defaultRoutingConfig"); } + FilterChainSelector(Map> routingConfigs) { + this(routingConfigs, null, new AtomicReference<>()); + } + @VisibleForTesting Map> getRoutingConfigs() { return routingConfigs; diff --git a/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java b/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java index 4295d75f59b..b3cc14c6484 100644 --- a/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java +++ 
b/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java @@ -18,11 +18,11 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import java.util.Comparator; import java.util.TreeSet; import java.util.concurrent.atomic.AtomicLong; -import javax.annotation.concurrent.GuardedBy; /** * Maintains the current xDS selector and any resources using that selector. When the selector diff --git a/xds/src/main/java/io/grpc/xds/FilterRegistry.java b/xds/src/main/java/io/grpc/xds/FilterRegistry.java index 7f1fe82c6c3..da3a59fe8c1 100644 --- a/xds/src/main/java/io/grpc/xds/FilterRegistry.java +++ b/xds/src/main/java/io/grpc/xds/FilterRegistry.java @@ -23,21 +23,22 @@ /** * A registry for all supported {@link Filter}s. Filters can be queried from the registry - * by any of the {@link Filter#typeUrls() type URLs}. + * by any of the {@link Filter.Provider#typeUrls() type URLs}. */ final class FilterRegistry { private static FilterRegistry instance; - private final Map supportedFilters = new HashMap<>(); + private final Map supportedFilters = new HashMap<>(); private FilterRegistry() {} static synchronized FilterRegistry getDefaultRegistry() { if (instance == null) { instance = newRegistry().register( - FaultFilter.INSTANCE, - RouterFilter.INSTANCE, - RbacFilter.INSTANCE); + new FaultFilter.Provider(), + new RouterFilter.Provider(), + new RbacFilter.Provider(), + new GcpAuthenticationFilter.Provider()); } return instance; } @@ -48,8 +49,8 @@ static FilterRegistry newRegistry() { } @VisibleForTesting - FilterRegistry register(Filter... filters) { - for (Filter filter : filters) { + FilterRegistry register(Filter.Provider... 
filters) { + for (Filter.Provider filter : filters) { for (String typeUrl : filter.typeUrls()) { supportedFilters.put(typeUrl, filter); } @@ -58,7 +59,7 @@ FilterRegistry register(Filter... filters) { } @Nullable - Filter get(String typeUrl) { + Filter.Provider get(String typeUrl) { return supportedFilters.get(typeUrl); } } diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java new file mode 100644 index 00000000000..8ec02f4f809 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java @@ -0,0 +1,327 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.XdsNameResolver.CLUSTER_SELECTION_KEY; +import static io.grpc.xds.XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY; + +import com.google.auth.oauth2.ComputeEngineCredentials; +import com.google.auth.oauth2.IdTokenCredentials; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.primitives.UnsignedLongs; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.Audience; +import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig; +import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.TokenCacheConfig; +import io.grpc.CallCredentials; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.CompositeCallCredentials; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.auth.MoreCallCredentials; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; +import io.grpc.xds.MetadataRegistry.MetadataValueParser; +import io.grpc.xds.XdsConfig.XdsClusterConfig; +import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.ScheduledExecutorService; +import java.util.function.Function; +import javax.annotation.Nullable; + +/** + * A {@link Filter} that injects a {@link CallCredentials} to handle + * authentication for xDS credentials. 
+ */ +final class GcpAuthenticationFilter implements Filter { + + static final String TYPE_URL = + "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig"; + private final LruCache callCredentialsCache; + final String filterInstanceName; + + GcpAuthenticationFilter(String name, int cacheSize) { + filterInstanceName = checkNotNull(name, "name"); + this.callCredentialsCache = new LruCache<>(cacheSize); + } + + static final class Provider implements Filter.Provider { + private final int cacheSize = 10; + + @Override + public String[] typeUrls() { + return new String[]{TYPE_URL}; + } + + @Override + public boolean isClientFilter() { + return true; + } + + @Override + public GcpAuthenticationFilter newInstance(String name) { + return new GcpAuthenticationFilter(name, cacheSize); + } + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + GcpAuthnFilterConfig gcpAuthnProto; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + } + Any anyMessage = (Any) rawProtoMessage; + + try { + gcpAuthnProto = anyMessage.unpack(GcpAuthnFilterConfig.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + + long cacheSize = 10; + // Validate cache_config + if (gcpAuthnProto.hasCacheConfig()) { + TokenCacheConfig cacheConfig = gcpAuthnProto.getCacheConfig(); + if (cacheConfig.hasCacheSize()) { + cacheSize = cacheConfig.getCacheSize().getValue(); + if (cacheSize == 0) { + return ConfigOrError.fromError( + "cache_config.cache_size must be greater than zero"); + } + } + + // LruCache's size is an int and briefly exceeds its maximum size before evicting entries + cacheSize = UnsignedLongs.min(cacheSize, Integer.MAX_VALUE - 1); + } + + GcpAuthenticationConfig config = new GcpAuthenticationConfig((int) cacheSize); + return ConfigOrError.fromConfig(config); + } + + @Override + public ConfigOrError 
parseFilterConfigOverride( + Message rawProtoMessage) { + return parseFilterConfig(rawProtoMessage); + } + } + + @Nullable + @Override + public ClientInterceptor buildClientInterceptor(FilterConfig config, + @Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) { + + ComputeEngineCredentials credentials = ComputeEngineCredentials.create(); + synchronized (callCredentialsCache) { + callCredentialsCache.resizeCache(((GcpAuthenticationConfig) config).getCacheSize()); + } + return new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + + String clusterName = callOptions.getOption(CLUSTER_SELECTION_KEY); + if (clusterName == null) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format( + "GCP Authn for %s does not contain cluster resource", filterInstanceName))); + } + + if (!clusterName.startsWith("cluster:")) { + return next.newCall(method, callOptions); + } + XdsConfig xdsConfig = callOptions.getOption(XDS_CONFIG_CALL_OPTION_KEY); + if (xdsConfig == null) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format( + "GCP Authn for %s with %s does not contain xds configuration", + filterInstanceName, clusterName))); + } + StatusOr xdsCluster = + xdsConfig.getClusters().get(clusterName.substring("cluster:".length())); + if (xdsCluster == null) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format( + "GCP Authn for %s with %s - xds cluster config does not contain xds cluster", + filterInstanceName, clusterName))); + } + if (!xdsCluster.hasValue()) { + return new FailingClientCall<>(xdsCluster.getStatus()); + } + Object audienceObj = + xdsCluster.getValue().getClusterResource().parsedMetadata().get(filterInstanceName); + if (audienceObj == null) { + return next.newCall(method, callOptions); + } + if (!(audienceObj instanceof AudienceWrapper)) { + return new 
FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format("GCP Authn found wrong type in %s metadata: %s=%s", + clusterName, filterInstanceName, audienceObj.getClass()))); + } + AudienceWrapper audience = (AudienceWrapper) audienceObj; + CallCredentials existingCallCredentials = callOptions.getCredentials(); + CallCredentials newCallCredentials = + getCallCredentials(callCredentialsCache, audience.audience, credentials); + if (existingCallCredentials != null) { + callOptions = callOptions.withCallCredentials( + new CompositeCallCredentials(existingCallCredentials, newCallCredentials)); + } else { + callOptions = callOptions.withCallCredentials(newCallCredentials); + } + return next.newCall(method, callOptions); + } + }; + } + + private CallCredentials getCallCredentials(LruCache cache, + String audience, ComputeEngineCredentials credentials) { + + synchronized (cache) { + return cache.getOrInsert(audience, key -> { + IdTokenCredentials creds = IdTokenCredentials.newBuilder() + .setIdTokenProvider(credentials) + .setTargetAudience(audience) + .build(); + return MoreCallCredentials.from(creds); + }); + } + } + + static final class GcpAuthenticationConfig implements FilterConfig { + + private final int cacheSize; + + public GcpAuthenticationConfig(int cacheSize) { + this.cacheSize = cacheSize; + } + + public int getCacheSize() { + return cacheSize; + } + + @Override + public String typeUrl() { + return GcpAuthenticationFilter.TYPE_URL; + } + } + + /** An implementation of {@link ClientCall} that fails when started. 
*/ + @VisibleForTesting + static final class FailingClientCall extends ClientCall { + + @VisibleForTesting + final Status error; + + public FailingClientCall(Status error) { + this.error = error; + } + + @Override + public void start(ClientCall.Listener listener, Metadata headers) { + listener.onClose(error, new Metadata()); + } + + @Override + public void request(int numMessages) {} + + @Override + public void cancel(String message, Throwable cause) {} + + @Override + public void halfClose() {} + + @Override + public void sendMessage(ReqT message) {} + } + + private static final class LruCache { + + private Map cache; + private int maxSize; + + LruCache(int maxSize) { + this.maxSize = maxSize; + this.cache = createEvictingMap(maxSize); + } + + V getOrInsert(K key, Function create) { + return cache.computeIfAbsent(key, create); + } + + private void resizeCache(int newSize) { + if (newSize >= maxSize) { + maxSize = newSize; + return; + } + Map newCache = createEvictingMap(newSize); + maxSize = newSize; + newCache.putAll(cache); + cache = newCache; + } + + private Map createEvictingMap(int size) { + return new LinkedHashMap(size, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > LruCache.this.maxSize; + } + }; + } + } + + static class AudienceMetadataParser implements MetadataValueParser { + + static final class AudienceWrapper { + final String audience; + + AudienceWrapper(String audience) { + this.audience = checkNotNull(audience); + } + } + + @Override + public String getTypeUrl() { + return "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.Audience"; + } + + @Override + public AudienceWrapper parse(Any any) throws ResourceInvalidException { + Audience audience; + try { + audience = any.unpack(Audience.class); + } catch (InvalidProtocolBufferException ex) { + throw new ResourceInvalidException("Invalid Resource in address proto", ex); + } + String url = audience.getUrl(); + if (url.isEmpty()) { + 
throw new ResourceInvalidException( + "Audience URL is empty. Metadata value must contain a valid URL."); + } + return new AudienceWrapper(url); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/GrpcBootstrapImplConfig.java b/xds/src/main/java/io/grpc/xds/GrpcBootstrapImplConfig.java new file mode 100644 index 00000000000..e119321fb6c --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/GrpcBootstrapImplConfig.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.auto.value.AutoValue; +import io.grpc.Internal; +import io.grpc.xds.client.AllowedGrpcServices; + +/** + * Custom configuration for gRPC xDS bootstrap implementation. 
+ */ +@Internal +@AutoValue +public abstract class GrpcBootstrapImplConfig { + public abstract AllowedGrpcServices allowedGrpcServices(); + + public static GrpcBootstrapImplConfig create(AllowedGrpcServices services) { + return new AutoValue_GrpcBootstrapImplConfig(services); + } +} diff --git a/xds/src/main/java/io/grpc/xds/GrpcBootstrapperImpl.java b/xds/src/main/java/io/grpc/xds/GrpcBootstrapperImpl.java index f61fab42cae..00a2e0d48d6 100644 --- a/xds/src/main/java/io/grpc/xds/GrpcBootstrapperImpl.java +++ b/xds/src/main/java/io/grpc/xds/GrpcBootstrapperImpl.java @@ -18,14 +18,21 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.CallCredentials; import io.grpc.ChannelCredentials; import io.grpc.internal.JsonUtil; +import io.grpc.xds.client.AllowedGrpcServices; +import io.grpc.xds.client.AllowedGrpcServices.AllowedGrpcService; import io.grpc.xds.client.BootstrapperImpl; +import io.grpc.xds.client.ConfiguredChannelCredentials; +import io.grpc.xds.client.ConfiguredChannelCredentials.ChannelCredsConfig; import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.client.XdsLogger; import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.Optional; import javax.annotation.Nullable; class GrpcBootstrapperImpl extends BootstrapperImpl { @@ -48,7 +55,11 @@ class GrpcBootstrapperImpl extends BootstrapperImpl { @Override public BootstrapInfo bootstrap(Map rawData) throws XdsInitializationException { - return super.bootstrap(rawData); + BootstrapInfo info = super.bootstrap(rawData); + if (info.servers().isEmpty()) { + throw new XdsInitializationException("Invalid bootstrap: 'xds_servers' is empty"); + } + return info; } /** @@ -92,29 +103,50 @@ protected String getJsonContent() throws XdsInitializationException, IOException @Override protected Object getImplSpecificConfig(Map serverConfig, 
String serverUri) throws XdsInitializationException { - return getChannelCredentials(serverConfig, serverUri); + ConfiguredChannelCredentials configuredChannel = getChannelCredentials(serverConfig, serverUri); + return configuredChannel != null ? configuredChannel.channelCredentials() : null; + } + + @GuardedBy("GrpcBootstrapperImpl.class") + private static Map defaultBootstrapOverride; + @GuardedBy("GrpcBootstrapperImpl.class") + private static BootstrapInfo defaultBootstrap; + + static synchronized void setDefaultBootstrapOverride(Map rawBootstrap) { + defaultBootstrapOverride = rawBootstrap; + } + + static synchronized BootstrapInfo defaultBootstrap() throws XdsInitializationException { + if (defaultBootstrap == null) { + if (defaultBootstrapOverride == null) { + defaultBootstrap = new GrpcBootstrapperImpl().bootstrap(); + } else { + defaultBootstrap = new GrpcBootstrapperImpl().bootstrap(defaultBootstrapOverride); + } + } + return defaultBootstrap; } - private static ChannelCredentials getChannelCredentials(Map serverConfig, - String serverUri) + private static ConfiguredChannelCredentials getChannelCredentials(Map serverConfig, + String serverUri) throws XdsInitializationException { List rawChannelCredsList = JsonUtil.getList(serverConfig, "channel_creds"); if (rawChannelCredsList == null || rawChannelCredsList.isEmpty()) { throw new XdsInitializationException( "Invalid bootstrap: server " + serverUri + " 'channel_creds' required"); } - ChannelCredentials channelCredentials = + ConfiguredChannelCredentials credentials = parseChannelCredentials(JsonUtil.checkObjectList(rawChannelCredsList), serverUri); - if (channelCredentials == null) { + if (credentials == null) { throw new XdsInitializationException( "Server " + serverUri + ": no supported channel credentials found"); } - return channelCredentials; + return credentials; } @Nullable - private static ChannelCredentials parseChannelCredentials(List> jsonList, - String serverUri) + private static 
ConfiguredChannelCredentials parseChannelCredentials(List> jsonList, + String serverUri) throws XdsInitializationException { for (Map channelCreds : jsonList) { String type = JsonUtil.getString(channelCreds, "type"); @@ -130,9 +162,95 @@ private static ChannelCredentials parseChannelCredentials(List> j config = ImmutableMap.of(); } - return provider.newChannelCredentials(config); + ChannelCredentials creds = provider.newChannelCredentials(config); + if (creds == null) { + return null; + } + return ConfiguredChannelCredentials.create(creds, new JsonChannelCredsConfig(type, config)); } } return null; } + + @Override + protected Optional parseImplSpecificObject( + @Nullable Map rawAllowedGrpcServices) + throws XdsInitializationException { + if (rawAllowedGrpcServices == null || rawAllowedGrpcServices.isEmpty()) { + return Optional.of(GrpcBootstrapImplConfig.create(AllowedGrpcServices.empty())); + } + + ImmutableMap.Builder builder = + ImmutableMap.builder(); + for (String targetUri : rawAllowedGrpcServices.keySet()) { + Map serviceConfig = JsonUtil.getObject(rawAllowedGrpcServices, targetUri); + if (serviceConfig == null) { + throw new XdsInitializationException( + "Invalid allowed_grpc_services config for " + targetUri); + } + ConfiguredChannelCredentials configuredChannel = + getChannelCredentials(serviceConfig, targetUri); + + Optional callCredentials = Optional.empty(); + List rawCallCredsList = JsonUtil.getList(serviceConfig, "call_creds"); + if (rawCallCredsList != null && !rawCallCredsList.isEmpty()) { + callCredentials = + parseCallCredentials(JsonUtil.checkObjectList(rawCallCredsList), targetUri); + } + + AllowedGrpcService.Builder b = AllowedGrpcService.builder() + .configuredChannelCredentials(configuredChannel); + callCredentials.ifPresent(b::callCredentials); + builder.put(targetUri, b.build()); + } + GrpcBootstrapImplConfig customConfig = + GrpcBootstrapImplConfig.create(AllowedGrpcServices.create(builder.build())); + return Optional.of(customConfig); + 
} + + @SuppressWarnings("unused") + private static Optional parseCallCredentials(List> jsonList, + String targetUri) + throws XdsInitializationException { + // TODO(sauravzg): Currently no xDS call credentials providers are implemented (no + // XdsCallCredentialsRegistry). + // As per A102/A97, we should just ignore unsupported call credentials types + // without throwing an exception. + return Optional.empty(); + } + + private static final class JsonChannelCredsConfig implements ChannelCredsConfig { + private final String type; + private final Map config; + + JsonChannelCredsConfig(String type, Map config) { + this.type = type; + this.config = config; + } + + @Override + public String type() { + return type; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + JsonChannelCredsConfig that = (JsonChannelCredsConfig) o; + return java.util.Objects.equals(type, that.type) + && java.util.Objects.equals(config, that.config); + } + + @Override + public int hashCode() { + return java.util.Objects.hash(type, config); + } + } + } + diff --git a/xds/src/main/java/io/grpc/xds/GrpcServiceConfigParser.java b/xds/src/main/java/io/grpc/xds/GrpcServiceConfigParser.java new file mode 100644 index 00000000000..1510924f74c --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/GrpcServiceConfigParser.java @@ -0,0 +1,339 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.OAuth2Credentials; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.Durations; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.xds.v3.XdsCredentials; +import io.grpc.CallCredentials; +import io.grpc.CompositeCallCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.Metadata; +import io.grpc.NameResolverRegistry; +import io.grpc.SecurityLevel; +import io.grpc.alts.GoogleDefaultChannelCredentials; +import io.grpc.auth.MoreCallCredentials; +import io.grpc.xds.client.AllowedGrpcServices; +import io.grpc.xds.client.AllowedGrpcServices.AllowedGrpcService; +import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.ConfiguredChannelCredentials; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig; +import io.grpc.xds.internal.grpcservice.GrpcServiceParseException; +import io.grpc.xds.internal.grpcservice.HeaderValue; +import io.grpc.xds.internal.grpcservice.HeaderValueValidationUtils; +import java.net.URI; +import java.net.URISyntaxException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.Executor; + +/** + * Parser for {@link io.envoyproxy.envoy.config.core.v3.GrpcService} and related protos. + */ +final class GrpcServiceConfigParser { + + static final String TLS_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." 
+ + "tls.v3.TlsCredentials"; + static final String LOCAL_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "local.v3.LocalCredentials"; + static final String XDS_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "xds.v3.XdsCredentials"; + static final String INSECURE_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "insecure.v3.InsecureCredentials"; + static final String GOOGLE_DEFAULT_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "google_default.v3.GoogleDefaultCredentials"; + + + + /** + * Parses the {@link io.envoyproxy.envoy.config.core.v3.GrpcService} proto to create a + * {@link GrpcServiceConfig} instance. + * + * @param grpcServiceProto The proto to parse. + * @return A {@link GrpcServiceConfig} instance. + * @throws GrpcServiceParseException if the proto is invalid or uses unsupported features. 
+ */ + public static GrpcServiceConfig parse(GrpcService grpcServiceProto, + Bootstrapper.BootstrapInfo bootstrapInfo, Bootstrapper.ServerInfo serverInfo) + throws GrpcServiceParseException { + if (!grpcServiceProto.hasGoogleGrpc()) { + throw new GrpcServiceParseException( + "Unsupported: GrpcService must have GoogleGrpc, got: " + grpcServiceProto); + } + GrpcServiceConfig.GoogleGrpcConfig googleGrpcConfig = + parseGoogleGrpcConfig(grpcServiceProto.getGoogleGrpc(), bootstrapInfo, serverInfo); + + GrpcServiceConfig.Builder builder = GrpcServiceConfig.builder().googleGrpc(googleGrpcConfig); + + ImmutableList.Builder initialMetadata = ImmutableList.builder(); + for (io.envoyproxy.envoy.config.core.v3.HeaderValue header : grpcServiceProto + .getInitialMetadataList()) { + String key = header.getKey(); + HeaderValue headerValue; + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + headerValue = HeaderValue.create(key, header.getRawValue()); + } else { + headerValue = HeaderValue.create(key, header.getValue()); + } + if (HeaderValueValidationUtils.isDisallowed(headerValue)) { + throw new GrpcServiceParseException("Invalid initial metadata header: " + key); + } + initialMetadata.add(headerValue); + } + builder.initialMetadata(initialMetadata.build()); + + if (grpcServiceProto.hasTimeout()) { + com.google.protobuf.Duration timeout = grpcServiceProto.getTimeout(); + if (!Durations.isValid(timeout) || Durations.compare(timeout, Durations.ZERO) <= 0) { + throw new GrpcServiceParseException("Timeout must be strictly positive and valid"); + } + builder.timeout(Duration.ofSeconds(timeout.getSeconds(), timeout.getNanos())); + } + return builder.build(); + } + + /** + * Parses the {@link io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc} proto to create a + * {@link GrpcServiceConfig.GoogleGrpcConfig} instance. + * + * @param googleGrpcProto The proto to parse. + * @return A {@link GrpcServiceConfig.GoogleGrpcConfig} instance. 
+ * @throws GrpcServiceParseException if the proto is invalid. + */ + public static GrpcServiceConfig.GoogleGrpcConfig parseGoogleGrpcConfig( + GrpcService.GoogleGrpc googleGrpcProto, Bootstrapper.BootstrapInfo bootstrapInfo, + Bootstrapper.ServerInfo serverInfo) throws GrpcServiceParseException { + + String targetUri = googleGrpcProto.getTargetUri(); + + AllowedGrpcServices allowedGrpcServices = + bootstrapInfo.implSpecificObject() + .filter(GrpcBootstrapImplConfig.class::isInstance) + .map(GrpcBootstrapImplConfig.class::cast) + .map(GrpcBootstrapImplConfig::allowedGrpcServices) + .orElse(AllowedGrpcServices.empty()); + + boolean isTrustedControlPlane = serverInfo.isTrustedXdsServer(); + Optional override = + Optional.ofNullable(allowedGrpcServices.services().get(targetUri)); + + boolean isTargetUriSchemeSupported = false; + try { + URI uri = new URI(targetUri); + String scheme = uri.getScheme(); + if (scheme == null) { + scheme = NameResolverRegistry.getDefaultRegistry().getDefaultScheme(); + } + if (scheme != null) { + isTargetUriSchemeSupported = + NameResolverRegistry.getDefaultRegistry().getProviderForScheme(scheme) != null; + } + } catch (URISyntaxException e) { + // Fallback or ignore if not a valid URI + } + + if (!isTargetUriSchemeSupported) { + throw new GrpcServiceParseException("Target URI scheme is not resolvable: " + targetUri); + } + + if (!isTrustedControlPlane) { + if (!override.isPresent()) { + throw new GrpcServiceParseException( + "Untrusted xDS server & URI not found in allowed_grpc_services: " + targetUri); + } + + GrpcServiceConfig.GoogleGrpcConfig.Builder builder = + GrpcServiceConfig.GoogleGrpcConfig.builder().target(targetUri) + .configuredChannelCredentials(override.get().configuredChannelCredentials()); + if (override.get().callCredentials().isPresent()) { + builder.callCredentials(override.get().callCredentials().get()); + } + return builder.build(); + } + + ConfiguredChannelCredentials channelCreds = + 
extractChannelCredentials(googleGrpcProto.getChannelCredentialsPluginList()); + + Optional callCreds = + extractCallCredentials(googleGrpcProto.getCallCredentialsPluginList()); + + GrpcServiceConfig.GoogleGrpcConfig.Builder builder = + GrpcServiceConfig.GoogleGrpcConfig.builder().target(googleGrpcProto.getTargetUri()) + .configuredChannelCredentials(channelCreds); + if (callCreds.isPresent()) { + builder.callCredentials(callCreds.get()); + } + return builder.build(); + } + + private static Optional channelCredsFromProto(Any cred) + throws GrpcServiceParseException { + String typeUrl = cred.getTypeUrl(); + try { + switch (typeUrl) { + case GOOGLE_DEFAULT_CREDENTIALS_TYPE_URL: + return Optional + .of(ConfiguredChannelCredentials.create(GoogleDefaultChannelCredentials.create(), + new ProtoChannelCredsConfig(typeUrl, cred))); + case INSECURE_CREDENTIALS_TYPE_URL: + return Optional.of(ConfiguredChannelCredentials.create( + InsecureChannelCredentials.create(), new ProtoChannelCredsConfig(typeUrl, cred))); + case XDS_CREDENTIALS_TYPE_URL: + XdsCredentials xdsConfig = cred.unpack(XdsCredentials.class); + Optional fallbackCreds = + channelCredsFromProto(xdsConfig.getFallbackCredentials()); + if (!fallbackCreds.isPresent()) { + throw new GrpcServiceParseException( + "Unsupported fallback credentials type for XdsCredentials"); + } + return Optional.of(ConfiguredChannelCredentials.create( + XdsChannelCredentials.create(fallbackCreds.get().channelCredentials()), + new ProtoChannelCredsConfig(typeUrl, cred))); + case LOCAL_CREDENTIALS_TYPE_URL: + throw new GrpcServiceParseException( + "LocalCredentials are not supported in grpc-java. " + + "See https://github.com/grpc/grpc-java/issues/8928"); + case TLS_CREDENTIALS_TYPE_URL: + // For this PR, we establish this structural skeleton, + // but throw an GrpcServiceParseException until the exact stream conversions are + // merged. 
+ throw new GrpcServiceParseException( + "TlsCredentials input stream construction pending."); + default: + return Optional.empty(); + } + } catch (InvalidProtocolBufferException e) { + throw new GrpcServiceParseException("Failed to parse channel credentials: " + e.getMessage()); + } + } + + private static ConfiguredChannelCredentials extractChannelCredentials( + List channelCredentialPlugins) throws GrpcServiceParseException { + for (Any cred : channelCredentialPlugins) { + Optional parsed = channelCredsFromProto(cred); + if (parsed.isPresent()) { + return parsed.get(); + } + } + throw new GrpcServiceParseException("No valid supported channel_credentials found"); + } + + private static Optional callCredsFromProto(Any cred) + throws GrpcServiceParseException { + if (cred.is(AccessTokenCredentials.class)) { + try { + AccessTokenCredentials accessToken = cred.unpack(AccessTokenCredentials.class); + if (accessToken.getToken().isEmpty()) { + throw new GrpcServiceParseException("Missing or empty access token in call credentials."); + } + return Optional + .of(new SecurityAwareAccessTokenCredentials(MoreCallCredentials.from(OAuth2Credentials + .create(new AccessToken(accessToken.getToken(), new Date(Long.MAX_VALUE)))))); + } catch (InvalidProtocolBufferException e) { + throw new GrpcServiceParseException( + "Failed to parse access token credentials: " + e.getMessage()); + } + } + return Optional.empty(); + } + + private static Optional extractCallCredentials(List callCredentialPlugins) + throws GrpcServiceParseException { + List creds = new ArrayList<>(); + for (Any cred : callCredentialPlugins) { + Optional parsed = callCredsFromProto(cred); + if (parsed.isPresent()) { + creds.add(parsed.get()); + } + } + return creds.stream().reduce(CompositeCallCredentials::new); + } + + private static final class SecurityAwareAccessTokenCredentials extends CallCredentials { + + private final CallCredentials delegate; + + SecurityAwareAccessTokenCredentials(CallCredentials delegate) { 
+ this.delegate = delegate; + } + + @Override + public void applyRequestMetadata(RequestInfo requestInfo, Executor appExecutor, + MetadataApplier applier) { + if (requestInfo.getSecurityLevel() == SecurityLevel.PRIVACY_AND_INTEGRITY) { + delegate.applyRequestMetadata(requestInfo, appExecutor, applier); + } else { + applier.apply(new Metadata()); + } + } + } + + static final class ProtoChannelCredsConfig + implements ConfiguredChannelCredentials.ChannelCredsConfig { + private final String type; + private final Any configProto; + + ProtoChannelCredsConfig(String type, Any configProto) { + this.type = type; + this.configProto = configProto; + } + + @Override + public String type() { + return type; + } + + Any configProto() { + return configProto; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ProtoChannelCredsConfig that = (ProtoChannelCredsConfig) o; + return java.util.Objects.equals(type, that.type) + && java.util.Objects.equals(configProto, that.configProto); + } + + @Override + public int hashCode() { + return java.util.Objects.hash(type, configProto); + } + } + + + +} diff --git a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java index 74c28ba2d2d..5100537aea2 100644 --- a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java +++ b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.ChannelCredentials; import io.grpc.ClientCall; @@ -30,39 +31,94 @@ import io.grpc.Status; import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.XdsTransportFactory; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import 
java.util.concurrent.TimeUnit; +/** + * A factory for creating gRPC-based transports for xDS communication. + * + *

WARNING: This class reuses channels when possible, based on the provided {@link + * Bootstrapper.ServerInfo} with important considerations. The {@link Bootstrapper.ServerInfo} + * includes {@link ChannelCredentials}, which is compared by reference equality. This means every + * {@link Bootstrapper.BootstrapInfo} would have non-equal copies of {@link + * Bootstrapper.ServerInfo}, even if they all represent the same xDS server configuration. For gRPC + * name resolution with the {@code xds} and {@code google-c2p} scheme, this transport sharing works + * as expected as it internally reuses a single {@link Bootstrapper.BootstrapInfo} instance. + * Otherwise, new transports would be created for each {@link Bootstrapper.ServerInfo} despite them + * possibly representing the same xDS server configuration and defeating the purpose of transport + * sharing. + */ final class GrpcXdsTransportFactory implements XdsTransportFactory { - static final GrpcXdsTransportFactory DEFAULT_XDS_TRANSPORT_FACTORY = - new GrpcXdsTransportFactory(); + private final CallCredentials callCredentials; + // The map of xDS server info to its corresponding gRPC xDS transport. + // This enables reusing and sharing the same underlying gRPC channel. + // + // NOTE: ConcurrentHashMap is used as a per-entry lock and all reads and writes must be a mutation + // via the ConcurrentHashMap APIs to acquire the per-entry lock in order to ensure thread safety + // for reference counting of each GrpcXdsTransport instance. 
+ private static final Map xdsServerInfoToTransportMap = + new ConcurrentHashMap<>(); + + GrpcXdsTransportFactory(CallCredentials callCredentials) { + this.callCredentials = callCredentials; + } @Override public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { - return new GrpcXdsTransport(serverInfo); + return xdsServerInfoToTransportMap.compute( + serverInfo, + (info, transport) -> { + if (transport == null) { + transport = new GrpcXdsTransport(serverInfo, callCredentials); + } + ++transport.refCount; + return transport; + }); } @VisibleForTesting public XdsTransport createForTest(ManagedChannel channel) { - return new GrpcXdsTransport(channel); + return new GrpcXdsTransport(channel, callCredentials, null); } @VisibleForTesting static class GrpcXdsTransport implements XdsTransport { private final ManagedChannel channel; + private final CallCredentials callCredentials; + private final Bootstrapper.ServerInfo serverInfo; + // Must only be accessed via the ConcurrentHashMap APIs which act as the locking methods. 
+ private int refCount = 0; public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) { + this(serverInfo, null); + } + + @VisibleForTesting + public GrpcXdsTransport(ManagedChannel channel) { + this(channel, null, null); + } + + public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials callCredentials) { String target = serverInfo.target(); ChannelCredentials channelCredentials = (ChannelCredentials) serverInfo.implSpecificConfig(); this.channel = Grpc.newChannelBuilder(target, channelCredentials) .keepAliveTime(5, TimeUnit.MINUTES) .build(); + this.callCredentials = callCredentials; + this.serverInfo = serverInfo; } @VisibleForTesting - public GrpcXdsTransport(ManagedChannel channel) { + public GrpcXdsTransport( + ManagedChannel channel, + CallCredentials callCredentials, + Bootstrapper.ServerInfo serverInfo) { this.channel = checkNotNull(channel, "channel"); + this.callCredentials = callCredentials; + this.serverInfo = serverInfo; } @Override @@ -72,7 +128,8 @@ public StreamingCall createStreamingCall( MethodDescriptor.Marshaller respMarshaller) { Context prevContext = Context.ROOT.attach(); try { - return new XdsStreamingCall<>(fullMethodName, reqMarshaller, respMarshaller); + return new XdsStreamingCall<>( + fullMethodName, reqMarshaller, respMarshaller, callCredentials); } finally { Context.ROOT.detach(prevContext); } @@ -81,7 +138,19 @@ public StreamingCall createStreamingCall( @Override public void shutdown() { - channel.shutdown(); + if (serverInfo == null) { + channel.shutdown(); + return; + } + xdsServerInfoToTransportMap.computeIfPresent( + serverInfo, + (info, transport) -> { + if (--transport.refCount == 0) { // Prefix decrement and return the updated value. + transport.channel.shutdown(); + return null; // Remove mapping. 
+ } + return transport; + }); } private class XdsStreamingCall implements @@ -89,16 +158,21 @@ private class XdsStreamingCall implements private final ClientCall call; - public XdsStreamingCall(String methodName, MethodDescriptor.Marshaller reqMarshaller, - MethodDescriptor.Marshaller respMarshaller) { - this.call = channel.newCall( - MethodDescriptor.newBuilder() - .setFullMethodName(methodName) - .setType(MethodDescriptor.MethodType.BIDI_STREAMING) - .setRequestMarshaller(reqMarshaller) - .setResponseMarshaller(respMarshaller) - .build(), - CallOptions.DEFAULT); // TODO(zivy): support waitForReady + public XdsStreamingCall( + String methodName, + MethodDescriptor.Marshaller reqMarshaller, + MethodDescriptor.Marshaller respMarshaller, + CallCredentials callCredentials) { + this.call = + channel.newCall( + MethodDescriptor.newBuilder() + .setFullMethodName(methodName) + .setType(MethodDescriptor.MethodType.BIDI_STREAMING) + .setRequestMarshaller(reqMarshaller) + .setResponseMarshaller(respMarshaller) + .build(), + CallOptions.DEFAULT.withCallCredentials( + callCredentials)); // TODO(zivy): support waitForReady } @Override diff --git a/xds/src/main/java/io/grpc/xds/InternalGrpcBootstrapperImpl.java b/xds/src/main/java/io/grpc/xds/InternalGrpcBootstrapperImpl.java index 929619c11d7..7bbc2a6dfca 100644 --- a/xds/src/main/java/io/grpc/xds/InternalGrpcBootstrapperImpl.java +++ b/xds/src/main/java/io/grpc/xds/InternalGrpcBootstrapperImpl.java @@ -17,8 +17,9 @@ package io.grpc.xds; import io.grpc.Internal; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.XdsInitializationException; -import java.io.IOException; +import java.util.Map; /** * Internal accessors for GrpcBootstrapperImpl. 
@@ -27,7 +28,8 @@ public final class InternalGrpcBootstrapperImpl { private InternalGrpcBootstrapperImpl() {} // prevent instantiation - public static String getJsonContent() throws XdsInitializationException, IOException { - return new GrpcBootstrapperImpl().getJsonContent(); + public static BootstrapInfo parseBootstrap(Map bootstrap) + throws XdsInitializationException { + return new GrpcBootstrapperImpl().bootstrap(bootstrap); } } diff --git a/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java b/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java index 54e6c748cd5..476adbf9cfd 100644 --- a/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java +++ b/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java @@ -19,8 +19,6 @@ import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC; import io.grpc.Internal; import io.grpc.ServerInterceptor; -import io.grpc.xds.RbacConfig; -import io.grpc.xds.RbacFilter; /** This class exposes some functionality in RbacFilter to other packages. */ @Internal @@ -30,11 +28,12 @@ private InternalRbacFilter() {} /** Parses RBAC filter config and creates AuthorizationServerInterceptor. 
*/ public static ServerInterceptor createInterceptor(RBAC rbac) { - ConfigOrError filterConfig = RbacFilter.parseRbacConfig(rbac); + ConfigOrError filterConfig = RbacFilter.Provider.parseRbacConfig(rbac); if (filterConfig.errorDetail != null) { throw new IllegalArgumentException( String.format("Failed to parse Rbac policy: %s", filterConfig.errorDetail)); } - return new RbacFilter().buildServerInterceptor(filterConfig.config, null); + return new RbacFilter.Provider().newInstance("internalRbacFilter") + .buildServerInterceptor(filterConfig.config, null); } } diff --git a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java index 0073cce1a88..cc5ff128274 100644 --- a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java @@ -16,8 +16,11 @@ package io.grpc.xds; +import io.grpc.CallCredentials; import io.grpc.Internal; +import io.grpc.MetricRecorder; import io.grpc.internal.ObjectPool; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsInitializationException; import java.util.Map; @@ -30,12 +33,79 @@ public final class InternalSharedXdsClientPoolProvider { // Prevent instantiation private InternalSharedXdsClientPoolProvider() {} + /** + * Override the global bootstrap. + * + * @deprecated Use InternalGrpcBootstrapperImpl.parseBootstrap() and pass the result to + * getOrCreate(). + */ + @Deprecated public static void setDefaultProviderBootstrapOverride(Map bootstrap) { - SharedXdsClientPoolProvider.getDefaultProvider().setBootstrapOverride(bootstrap); + GrpcBootstrapperImpl.setDefaultBootstrapOverride(bootstrap); } + /** + * Get an XdsClient pool. + * + * @deprecated Use InternalGrpcBootstrapperImpl.parseBootstrap() and pass the result to the other + * getOrCreate(). 
+ */ + @Deprecated public static ObjectPool getOrCreate(String target) throws XdsInitializationException { - return SharedXdsClientPoolProvider.getDefaultProvider().getOrCreate(target); + return getOrCreate(target, new MetricRecorder() {}); + } + + /** + * Get an XdsClient pool. + * + * @deprecated Use InternalGrpcBootstrapperImpl.parseBootstrap() and pass the result to the other + * getOrCreate(). + */ + @Deprecated + public static ObjectPool getOrCreate(String target, MetricRecorder metricRecorder) + throws XdsInitializationException { + return getOrCreate(target, metricRecorder, null); + } + + /** + * Get an XdsClient pool. + * + * @deprecated Use InternalGrpcBootstrapperImpl.parseBootstrap() and pass the result to the other + * getOrCreate(). + */ + @Deprecated + public static ObjectPool getOrCreate( + String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials) + throws XdsInitializationException { + return SharedXdsClientPoolProvider.getDefaultProvider() + .getOrCreate(target, metricRecorder, transportCallCredentials); + } + + public static XdsClientResult getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder, + CallCredentials transportCallCredentials) { + return new XdsClientResult(SharedXdsClientPoolProvider.getDefaultProvider() + .getOrCreate(target, bootstrapInfo, metricRecorder, transportCallCredentials)); + } + + /** + * An ObjectPool, except without exposing io.grpc.internal, which must not be used for + * cross-package APIs. 
+ */ + public static final class XdsClientResult { + private final ObjectPool xdsClientPool; + + XdsClientResult(ObjectPool xdsClientPool) { + this.xdsClientPool = xdsClientPool; + } + + public XdsClient getObject() { + return xdsClientPool.getObject(); + } + + public XdsClient returnObject(XdsClient xdsClient) { + return xdsClientPool.returnObject(xdsClient); + } } } diff --git a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java index aaaeb198d21..ed70e6f5e78 100644 --- a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java +++ b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 The gRPC Authors + * Copyright 2024 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,91 +18,19 @@ import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; -import io.grpc.Grpc; import io.grpc.Internal; -import io.grpc.NameResolver; -import io.grpc.internal.ObjectPool; -import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; -import io.grpc.xds.client.Locality; -import io.grpc.xds.client.XdsClient; -import io.grpc.xds.internal.security.SslContextProviderSupplier; /** * Internal attributes used for xDS implementation. Do not use. */ @Internal public final class InternalXdsAttributes { - - // TODO(sanjaypujare): move to xds internal package. - /** Attribute key for SslContextProviderSupplier (used from client) for a subchannel. */ - @Grpc.TransportAttr - public static final Attributes.Key - ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER = - Attributes.Key.create("io.grpc.xds.internal.security.SslContextProviderSupplier"); - - /** - * Attribute key for passing around the XdsClient object pool across NameResolver/LoadBalancers. 
- */ - @NameResolver.ResolutionResultAttr - static final Attributes.Key> XDS_CLIENT_POOL = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.xdsClientPool"); - - /** - * Attribute key for obtaining the global provider that provides atomics for aggregating - * outstanding RPCs sent to each cluster. - */ - @NameResolver.ResolutionResultAttr - static final Attributes.Key CALL_COUNTER_PROVIDER = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.callCounterProvider"); - - /** - * Map from localities to their weights. - */ - @NameResolver.ResolutionResultAttr - static final Attributes.Key ATTR_LOCALITY_WEIGHT = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.localityWeight"); - /** * Name of the cluster that provides this EquivalentAddressGroup. */ - @Internal @EquivalentAddressGroup.Attr public static final Attributes.Key ATTR_CLUSTER_NAME = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.clusterName"); - - /** - * The locality that this EquivalentAddressGroup is in. - */ - @EquivalentAddressGroup.Attr - static final Attributes.Key ATTR_LOCALITY = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.locality"); - - /** - * The name of the locality that this EquivalentAddressGroup is in. - */ - @EquivalentAddressGroup.Attr - static final Attributes.Key ATTR_LOCALITY_NAME = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.localityName"); - - /** - * Endpoint weight for load balancing purposes. - */ - @EquivalentAddressGroup.Attr - static final Attributes.Key ATTR_SERVER_WEIGHT = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.serverWeight"); - - /** - * Filter chain match for network filters. - */ - @Grpc.TransportAttr - static final Attributes.Key - ATTR_FILTER_CHAIN_SELECTOR_MANAGER = Attributes.Key.create( - "io.grpc.xds.InternalXdsAttributes.filterChainSelectorManager"); - - /** Grace time to use when draining. Null for an infinite grace time. 
*/ - @Grpc.TransportAttr - static final Attributes.Key ATTR_DRAIN_GRACE_NANOS = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.drainGraceTime"); + XdsAttributes.ATTR_CLUSTER_NAME; private InternalXdsAttributes() {} } diff --git a/xds/src/main/java/io/grpc/xds/LazyLoadBalancer.java b/xds/src/main/java/io/grpc/xds/LazyLoadBalancer.java index 87f1b72ca47..b5f09c4ea93 100644 --- a/xds/src/main/java/io/grpc/xds/LazyLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/LazyLoadBalancer.java @@ -99,19 +99,16 @@ public void requestConnection() { @Override public void shutdown() { + delegate = new NoopLoadBalancer(); } private final class LazyPicker extends SubchannelPicker { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { + // activate() is a no-op after shutdown() helper.getSynchronizationContext().execute(LazyDelegate.this::activate); return PickResult.withNoResult(); } - - @Override - public void requestConnection() { - helper.getSynchronizationContext().execute(LazyDelegate.this::requestConnection); - } } } @@ -126,4 +123,17 @@ public Factory(LoadBalancer.Factory delegate) { return new LazyLoadBalancer(helper, delegate); } } + + private static final class NoopLoadBalancer extends LoadBalancer { + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + return Status.OK; + } + + @Override + public void handleNameResolutionError(Status error) {} + + @Override + public void shutdown() {} + } } diff --git a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java index 6c13530ff49..1f23f2a4af5 100644 --- a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java @@ -32,7 +32,6 @@ import io.grpc.ClientStreamTracer; import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.ConnectivityState; -import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import 
io.grpc.LoadBalancerProvider; import io.grpc.Metadata; @@ -55,7 +54,7 @@ final class LeastRequestLoadBalancer extends MultiChildLoadBalancer { private final ThreadSafeRandom random; - private SubchannelPicker currentPicker = new EmptyPicker(); + private SubchannelPicker currentPicker = new FixedResultPicker(PickResult.withNoResult()); private int choiceCount = DEFAULT_CHOICE_COUNT; LeastRequestLoadBalancer(Helper helper) { @@ -114,7 +113,7 @@ protected void updateOverallBalancingState() { } } if (isConnecting) { - updateBalancingState(CONNECTING, new EmptyPicker()); + updateBalancingState(CONNECTING, new FixedResultPicker(PickResult.withNoResult())); } else { // Give it all the failing children and let it randomly pick among them updateBalancingState(TRANSIENT_FAILURE, @@ -155,7 +154,6 @@ private static AtomicInteger getInFlights(ChildLbState childLbState) { static final class ReadyPicker extends SubchannelPicker { private final List childPickers; // non-empty private final List childInFlights; // 1:1 with childPickers - private final List childEags; // 1:1 with childPickers private final int choiceCount; private final ThreadSafeRandom random; private final int hashCode; @@ -164,11 +162,9 @@ static final class ReadyPicker extends SubchannelPicker { checkArgument(!childLbStates.isEmpty(), "empty list"); this.childPickers = new ArrayList<>(childLbStates.size()); this.childInFlights = new ArrayList<>(childLbStates.size()); - this.childEags = new ArrayList<>(childLbStates.size()); for (ChildLbState state : childLbStates) { childPickers.add(state.getCurrentPicker()); childInFlights.add(getInFlights(state)); - childEags.add(state.getEag()); } this.choiceCount = choiceCount; this.random = checkNotNull(random, "random"); @@ -224,11 +220,6 @@ List getChildPickers() { return childPickers; } - @VisibleForTesting - List getChildEags() { - return childEags; - } - @Override public int hashCode() { return hashCode; diff --git 
a/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java b/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java index e08ea0fab43..5fd8ec5526e 100644 --- a/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java +++ b/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java @@ -91,6 +91,7 @@ class LoadBalancerConfigFactory { static final String SHUFFLE_ADDRESS_LIST_FIELD_NAME = "shuffleAddressList"; static final String ERROR_UTILIZATION_PENALTY = "errorUtilizationPenalty"; + static final String METRIC_NAMES_FOR_COMPUTING_UTILIZATION = "metricNamesForComputingUtilization"; /** * Factory method for creating a new {link LoadBalancerConfigConverter} for a given xDS {@link @@ -134,11 +135,9 @@ class LoadBalancerConfigFactory { * the given config values. */ private static ImmutableMap buildWrrConfig(String blackoutPeriod, - String weightExpirationPeriod, - String oobReportingPeriod, - Boolean enableOobLoadReport, - String weightUpdatePeriod, - Float errorUtilizationPenalty) { + String weightExpirationPeriod, String oobReportingPeriod, Boolean enableOobLoadReport, + String weightUpdatePeriod, Float errorUtilizationPenalty, + ImmutableList metricNamesForComputingUtilization) { ImmutableMap.Builder configBuilder = ImmutableMap.builder(); if (blackoutPeriod != null) { configBuilder.put(BLACK_OUT_PERIOD, blackoutPeriod); @@ -158,6 +157,10 @@ class LoadBalancerConfigFactory { if (errorUtilizationPenalty != null) { configBuilder.put(ERROR_UTILIZATION_PENALTY, errorUtilizationPenalty); } + if (metricNamesForComputingUtilization != null + && !metricNamesForComputingUtilization.isEmpty()) { + configBuilder.put(METRIC_NAMES_FOR_COMPUTING_UTILIZATION, metricNamesForComputingUtilization); + } return ImmutableMap.of(WeightedRoundRobinLoadBalancerProvider.SCHEME, configBuilder.buildOrThrow()); } @@ -284,7 +287,7 @@ static class LoadBalancingPolicyConverter { } private static ImmutableMap convertWeightedRoundRobinConfig( - ClientSideWeightedRoundRobin wrr) throws 
ResourceInvalidException { + ClientSideWeightedRoundRobin wrr) throws ResourceInvalidException { try { return buildWrrConfig( wrr.hasBlackoutPeriod() ? Durations.toString(wrr.getBlackoutPeriod()) : null, @@ -293,7 +296,8 @@ static class LoadBalancingPolicyConverter { wrr.hasOobReportingPeriod() ? Durations.toString(wrr.getOobReportingPeriod()) : null, wrr.hasEnableOobLoadReport() ? wrr.getEnableOobLoadReport().getValue() : null, wrr.hasWeightUpdatePeriod() ? Durations.toString(wrr.getWeightUpdatePeriod()) : null, - wrr.hasErrorUtilizationPenalty() ? wrr.getErrorUtilizationPenalty().getValue() : null); + wrr.hasErrorUtilizationPenalty() ? wrr.getErrorUtilizationPenalty().getValue() : null, + ImmutableList.copyOf(wrr.getMetricNamesForComputingUtilizationList())); } catch (IllegalArgumentException ex) { throw new ResourceInvalidException("Invalid duration in weighted round robin config: " + ex.getMessage()); diff --git a/xds/src/main/java/io/grpc/xds/MessagePrinter.java b/xds/src/main/java/io/grpc/xds/MessagePrinter.java index 5927bfd517e..d6fdaa81dd7 100644 --- a/xds/src/main/java/io/grpc/xds/MessagePrinter.java +++ b/xds/src/main/java/io/grpc/xds/MessagePrinter.java @@ -16,6 +16,7 @@ package io.grpc.xds; +import com.github.xds.type.v3.TypedStruct; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; @@ -32,8 +33,11 @@ import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBACPerRoute; import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; +import io.envoyproxy.envoy.extensions.load_balancing_policies.round_robin.v3.RoundRobin; +import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext; import 
io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; +import io.envoyproxy.envoy.service.discovery.v3.Resource; import io.grpc.xds.client.MessagePrettyPrinter; /** @@ -52,6 +56,7 @@ private static class LazyHolder { private static JsonFormat.Printer newPrinter() { TypeRegistry.Builder registry = TypeRegistry.newBuilder() + .add(Resource.getDescriptor()) .add(Listener.getDescriptor()) .add(HttpConnectionManager.getDescriptor()) .add(HTTPFault.getDescriptor()) @@ -65,7 +70,10 @@ private static JsonFormat.Printer newPrinter() { .add(RouteConfiguration.getDescriptor()) .add(Cluster.getDescriptor()) .add(ClusterConfig.getDescriptor()) - .add(ClusterLoadAssignment.getDescriptor()); + .add(ClusterLoadAssignment.getDescriptor()) + .add(WrrLocality.getDescriptor()) + .add(TypedStruct.getDescriptor()) + .add(RoundRobin.getDescriptor()); try { @SuppressWarnings("unchecked") Class routeLookupClusterSpecifierClass = diff --git a/xds/src/main/java/io/grpc/xds/MetadataRegistry.java b/xds/src/main/java/io/grpc/xds/MetadataRegistry.java new file mode 100644 index 00000000000..b79a61a261a --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/MetadataRegistry.java @@ -0,0 +1,125 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Any; +import com.google.protobuf.Struct; +import io.envoyproxy.envoy.config.core.v3.Metadata; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser; +import io.grpc.xds.XdsEndpointResource.AddressMetadataParser; +import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; +import io.grpc.xds.internal.ProtobufJsonConverter; +import java.util.HashMap; +import java.util.Map; + +/** + * Registry for parsing cluster metadata values. + * + *

This class maintains a mapping of type URLs to {@link MetadataValueParser} instances, + * allowing for the parsing of different metadata types. + */ +final class MetadataRegistry { + private static final MetadataRegistry INSTANCE = new MetadataRegistry(); + + private final Map supportedParsers = new HashMap<>(); + + private MetadataRegistry() { + registerParser(new AudienceMetadataParser()); + registerParser(new AddressMetadataParser()); + } + + static MetadataRegistry getInstance() { + return INSTANCE; + } + + MetadataValueParser findParser(String typeUrl) { + return supportedParsers.get(typeUrl); + } + + @VisibleForTesting + void registerParser(MetadataValueParser parser) { + supportedParsers.put(parser.getTypeUrl(), parser); + } + + void removeParser(MetadataValueParser parser) { + supportedParsers.remove(parser.getTypeUrl()); + } + + /** + * Parses cluster metadata into a structured map. + * + *

Values in {@code typed_filter_metadata} take precedence over + * {@code filter_metadata} when keys overlap, following Envoy API behavior. See + * + * Envoy metadata documentation for details. + * + * @param metadata the {@link Metadata} containing the fields to parse. + * @return an immutable map of parsed metadata. + * @throws ResourceInvalidException if parsing {@code typed_filter_metadata} fails. + */ + public ImmutableMap parseMetadata(Metadata metadata) + throws ResourceInvalidException { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + + // Process typed_filter_metadata + for (Map.Entry entry : metadata.getTypedFilterMetadataMap().entrySet()) { + String key = entry.getKey(); + Any value = entry.getValue(); + MetadataValueParser parser = findParser(value.getTypeUrl()); + if (parser != null) { + try { + Object parsedValue = parser.parse(value); + parsedMetadata.put(key, parsedValue); + } catch (ResourceInvalidException e) { + throw new ResourceInvalidException( + String.format("Failed to parse metadata key: %s, type: %s. Error: %s", + key, value.getTypeUrl(), e.getMessage()), e); + } + } + } + // building once to reuse in the next loop + ImmutableMap intermediateParsedMetadata = parsedMetadata.build(); + + // Process filter_metadata for remaining keys + for (Map.Entry entry : metadata.getFilterMetadataMap().entrySet()) { + String key = entry.getKey(); + if (!intermediateParsedMetadata.containsKey(key)) { + Struct structValue = entry.getValue(); + Object jsonValue = ProtobufJsonConverter.convertToJson(structValue); + parsedMetadata.put(key, jsonValue); + } + } + + return parsedMetadata.build(); + } + + interface MetadataValueParser { + + String getTypeUrl(); + + /** + * Parses the given {@link Any} object into a specific metadata value. + * + * @param any the {@link Any} object to parse. + * @return the parsed metadata value. + * @throws ResourceInvalidException if the parsing fails. 
+ */ + Object parse(Any any) throws ResourceInvalidException; + } +} diff --git a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java index 259ae7406c1..6e4566de76d 100644 --- a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java @@ -91,6 +91,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { checkNotNull(config, "missing priority lb config"); priorityNames = config.priorities; priorityConfigs = config.childConfigs; + Status status = Status.OK; Set prioritySet = new HashSet<>(config.priorities); ArrayList childKeys = new ArrayList<>(children.keySet()); for (String priority : childKeys) { @@ -105,12 +106,18 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { for (String priority : priorityNames) { ChildLbState childLbState = children.get(priority); if (childLbState != null) { - childLbState.updateResolvedAddresses(); + Status newStatus = childLbState.updateResolvedAddresses(); + if (!newStatus.isOk()) { + status = newStatus; + } } } handlingResolvedAddresses = false; - tryNextPriority(); - return Status.OK; + Status newStatus = tryNextPriority(); + if (!newStatus.isOk()) { + status = newStatus; + } + return status; } @Override @@ -140,19 +147,19 @@ public void shutdown() { children.clear(); } - private void tryNextPriority() { + private Status tryNextPriority() { for (int i = 0; i < priorityNames.size(); i++) { String priority = priorityNames.get(i); if (!children.containsKey(priority)) { ChildLbState child = new ChildLbState(priority, priorityConfigs.get(priority).ignoreReresolution); children.put(priority, child); - updateOverallState(priority, CONNECTING, new FixedResultPicker(PickResult.withNoResult())); + // Child is created in CONNECTING with pending failOverTimer + updateOverallState(priority, child.connectivityState, child.picker); // Calling the child's 
updateResolvedAddresses() can result in tryNextPriority() being // called recursively. We need to be sure to be done with processing here before it is // called. - child.updateResolvedAddresses(); - return; // Give priority i time to connect. + return child.updateResolvedAddresses(); // Give priority i time to connect. } ChildLbState child = children.get(priority); child.reactivate(); @@ -165,23 +172,26 @@ private void tryNextPriority() { children.get(p).deactivate(); } } - return; + return Status.OK; } - if (child.failOverTimer != null && child.failOverTimer.isPending()) { + if (child.failOverTimer.isPending()) { updateOverallState(priority, child.connectivityState, child.picker); - return; // Give priority i time to connect. + return Status.OK; // Give priority i time to connect. } - if (priority.equals(currentPriority) && child.connectivityState != TRANSIENT_FAILURE) { - // If the current priority is not changed into TRANSIENT_FAILURE, keep using it. + } + for (int i = 0; i < priorityNames.size(); i++) { + String priority = priorityNames.get(i); + ChildLbState child = children.get(priority); + if (child.connectivityState.equals(CONNECTING)) { updateOverallState(priority, child.connectivityState, child.picker); - return; + return Status.OK; } } - // TODO(zdapeng): Include error details of each priority. logger.log(XdsLogLevel.DEBUG, "All priority failed"); String lastPriority = priorityNames.get(priorityNames.size() - 1); - SubchannelPicker errorPicker = children.get(lastPriority).picker; - updateOverallState(lastPriority, TRANSIENT_FAILURE, errorPicker); + ChildLbState child = children.get(lastPriority); + updateOverallState(lastPriority, child.connectivityState, child.picker); + return Status.OK; } private void updateOverallState( @@ -224,11 +234,12 @@ public void run() { // The child is deactivated. 
return; } - picker = new FixedResultPicker(PickResult.withError( - Status.UNAVAILABLE.withDescription("Connection timeout for priority " + priority))); logger.log(XdsLogLevel.DEBUG, "Priority {0} failed over to next", priority); - currentPriority = null; // reset currentPriority to guarantee failover happen - tryNextPriority(); + Status status = tryNextPriority(); + if (!status.isOk()) { + // A child had a problem with the addresses/config. Request it to be refreshed + helper.refreshNameResolution(); + } } } @@ -279,10 +290,10 @@ void tearDown() { * resolvedAddresses}, or when priority lb receives a new resolved addresses while the child * already exists. */ - void updateResolvedAddresses() { + Status updateResolvedAddresses() { PriorityLbConfig config = (PriorityLbConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - lb.handleResolvedAddresses( + return lb.acceptResolvedAddresses( resolvedAddresses.toBuilder() .setAddresses(AddressFilter.filter(resolvedAddresses.getAddresses(), priority)) .setLoadBalancingPolicyConfig(config.childConfigs.get(priority).childConfig) @@ -309,13 +320,14 @@ public void updateBalancingState(final ConnectivityState newState, if (!children.containsKey(priority)) { return; } + ConnectivityState oldState = connectivityState; connectivityState = newState; picker = newPicker; if (deletionTimer != null && deletionTimer.isPending()) { return; } - if (newState.equals(CONNECTING)) { + if (newState.equals(CONNECTING) && !oldState.equals(newState)) { if (!failOverTimer.isPending() && seenReadyOrIdleSinceTransientFailure) { failOverTimer = syncContext.schedule(new FailOverTask(), 10, TimeUnit.SECONDS, executor); @@ -331,7 +343,11 @@ public void updateBalancingState(final ConnectivityState newState, // If we are currently handling newly resolved addresses, let's not try to reconfigure as // the address handling process will take care of that to provide an atomic config update. 
if (!handlingResolvedAddresses) { - tryNextPriority(); + Status status = tryNextPriority(); + if (!status.isOk()) { + // A child had a problem with the addresses/config. Request it to be refreshed + helper.refreshNameResolution(); + } } } diff --git a/xds/src/main/java/io/grpc/xds/RbacFilter.java b/xds/src/main/java/io/grpc/xds/RbacFilter.java index 6a55f7f193e..91df1e68802 100644 --- a/xds/src/main/java/io/grpc/xds/RbacFilter.java +++ b/xds/src/main/java/io/grpc/xds/RbacFilter.java @@ -18,7 +18,6 @@ import static com.google.common.base.Preconditions.checkNotNull; -import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; @@ -34,7 +33,6 @@ import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.Status; -import io.grpc.xds.Filter.ServerInterceptorBuilder; import io.grpc.xds.internal.MatcherParser; import io.grpc.xds.internal.Matchers; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine; @@ -66,10 +64,10 @@ import javax.annotation.Nullable; /** RBAC Http filter implementation. 
*/ -final class RbacFilter implements Filter, ServerInterceptorBuilder { +final class RbacFilter implements Filter { private static final Logger logger = Logger.getLogger(RbacFilter.class.getName()); - static final RbacFilter INSTANCE = new RbacFilter(); + private static final RbacFilter INSTANCE = new RbacFilter(); static final String TYPE_URL = "type.googleapis.com/envoy.extensions.filters.http.rbac.v3.RBAC"; @@ -77,87 +75,99 @@ final class RbacFilter implements Filter, ServerInterceptorBuilder { private static final String TYPE_URL_OVERRIDE_CONFIG = "type.googleapis.com/envoy.extensions.filters.http.rbac.v3.RBACPerRoute"; - RbacFilter() {} + private RbacFilter() {} - @Override - public String[] typeUrls() { - return new String[] { TYPE_URL, TYPE_URL_OVERRIDE_CONFIG }; - } + static final class Provider implements Filter.Provider { + @Override + public String[] typeUrls() { + return new String[] {TYPE_URL, TYPE_URL_OVERRIDE_CONFIG}; + } - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - RBAC rbacProto; - if (!(rawProtoMessage instanceof Any)) { - return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + @Override + public boolean isServerFilter() { + return true; } - Any anyMessage = (Any) rawProtoMessage; - try { - rbacProto = anyMessage.unpack(RBAC.class); - } catch (InvalidProtocolBufferException e) { - return ConfigOrError.fromError("Invalid proto: " + e); + + @Override + public RbacFilter newInstance(String name) { + return INSTANCE; } - return parseRbacConfig(rbacProto); - } - @VisibleForTesting - static ConfigOrError parseRbacConfig(RBAC rbac) { - if (!rbac.hasRules()) { - return ConfigOrError.fromConfig(RbacConfig.create(null)); + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + RBAC rbacProto; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + } + Any anyMessage = (Any) 
rawProtoMessage; + try { + rbacProto = anyMessage.unpack(RBAC.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + return parseRbacConfig(rbacProto); } - io.envoyproxy.envoy.config.rbac.v3.RBAC rbacConfig = rbac.getRules(); - GrpcAuthorizationEngine.Action authAction; - switch (rbacConfig.getAction()) { - case ALLOW: - authAction = GrpcAuthorizationEngine.Action.ALLOW; - break; - case DENY: - authAction = GrpcAuthorizationEngine.Action.DENY; - break; - case LOG: + + @Override + public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { + RBACPerRoute rbacPerRoute; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + } + Any anyMessage = (Any) rawProtoMessage; + try { + rbacPerRoute = anyMessage.unpack(RBACPerRoute.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + if (rbacPerRoute.hasRbac()) { + return parseRbacConfig(rbacPerRoute.getRbac()); + } else { return ConfigOrError.fromConfig(RbacConfig.create(null)); - case UNRECOGNIZED: - default: - return ConfigOrError.fromError("Unknown rbacConfig action type: " + rbacConfig.getAction()); + } } - List policyMatchers = new ArrayList<>(); - List> sortedPolicyEntries = rbacConfig.getPoliciesMap().entrySet() - .stream() - .sorted((a,b) -> a.getKey().compareTo(b.getKey())) - .collect(Collectors.toList()); - for (Map.Entry entry: sortedPolicyEntries) { - try { - Policy policy = entry.getValue(); - if (policy.hasCondition() || policy.hasCheckedCondition()) { + + static ConfigOrError parseRbacConfig(RBAC rbac) { + if (!rbac.hasRules()) { + return ConfigOrError.fromConfig(RbacConfig.create(null)); + } + io.envoyproxy.envoy.config.rbac.v3.RBAC rbacConfig = rbac.getRules(); + GrpcAuthorizationEngine.Action authAction; + switch (rbacConfig.getAction()) { + case ALLOW: + authAction = 
GrpcAuthorizationEngine.Action.ALLOW; + break; + case DENY: + authAction = GrpcAuthorizationEngine.Action.DENY; + break; + case LOG: + return ConfigOrError.fromConfig(RbacConfig.create(null)); + case UNRECOGNIZED: + default: return ConfigOrError.fromError( - "Policy.condition and Policy.checked_condition must not set: " + entry.getKey()); + "Unknown rbacConfig action type: " + rbacConfig.getAction()); + } + List policyMatchers = new ArrayList<>(); + List> sortedPolicyEntries = rbacConfig.getPoliciesMap().entrySet() + .stream() + .sorted((a,b) -> a.getKey().compareTo(b.getKey())) + .collect(Collectors.toList()); + for (Map.Entry entry: sortedPolicyEntries) { + try { + Policy policy = entry.getValue(); + if (policy.hasCondition() || policy.hasCheckedCondition()) { + return ConfigOrError.fromError( + "Policy.condition and Policy.checked_condition must not set: " + entry.getKey()); + } + policyMatchers.add(PolicyMatcher.create(entry.getKey(), + parsePermissionList(policy.getPermissionsList()), + parsePrincipalList(policy.getPrincipalsList()))); + } catch (Exception e) { + return ConfigOrError.fromError("Encountered error parsing policy: " + e); } - policyMatchers.add(PolicyMatcher.create(entry.getKey(), - parsePermissionList(policy.getPermissionsList()), - parsePrincipalList(policy.getPrincipalsList()))); - } catch (Exception e) { - return ConfigOrError.fromError("Encountered error parsing policy: " + e); } - } - return ConfigOrError.fromConfig(RbacConfig.create( - AuthConfig.create(policyMatchers, authAction))); - } - - @Override - public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { - RBACPerRoute rbacPerRoute; - if (!(rawProtoMessage instanceof Any)) { - return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); - } - Any anyMessage = (Any) rawProtoMessage; - try { - rbacPerRoute = anyMessage.unpack(RBACPerRoute.class); - } catch (InvalidProtocolBufferException e) { - return ConfigOrError.fromError("Invalid proto: " + 
e); - } - if (rbacPerRoute.hasRbac()) { - return parseRbacConfig(rbacPerRoute.getRbac()); - } else { - return ConfigOrError.fromConfig(RbacConfig.create(null)); + return ConfigOrError.fromConfig(RbacConfig.create( + AuthConfig.create(policyMatchers, authAction))); } } @@ -266,8 +276,13 @@ private static Matcher parsePrincipal(Principal principal) { return createSourceIpMatcher(principal.getDirectRemoteIp()); case REMOTE_IP: return createSourceIpMatcher(principal.getRemoteIp()); - case SOURCE_IP: - return createSourceIpMatcher(principal.getSourceIp()); + case SOURCE_IP: { + // gRFC A41 has identical handling of source_ip as remote_ip and direct_remote_ip and + // pre-dates the deprecation. + @SuppressWarnings("deprecation") + CidrRange sourceIp = principal.getSourceIp(); + return createSourceIpMatcher(sourceIp); + } case HEADER: return parseHeaderMatcher(principal.getHeader()); case NOT_ID: diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java index 4f93974b52c..513f4d643ea 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java @@ -25,6 +25,8 @@ import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Joiner; import com.google.common.base.MoreObjects; import com.google.common.collect.HashMultiset; import com.google.common.collect.Multiset; @@ -34,9 +36,11 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; +import io.grpc.Metadata; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.util.MultiChildLoadBalancer; +import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import io.grpc.xds.client.XdsLogger; import io.grpc.xds.client.XdsLogger.XdsLogLevel; import java.net.SocketAddress; @@ -47,6 +51,7 @@ import 
java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -69,13 +74,21 @@ final class RingHashLoadBalancer extends MultiChildLoadBalancer { new LazyLoadBalancer.Factory(pickFirstLbProvider); private final XdsLogger logger; private final SynchronizationContext syncContext; + private final ThreadSafeRandom random; private List ring; + @Nullable private Metadata.Key requestHashHeaderKey; RingHashLoadBalancer(Helper helper) { + this(helper, ThreadSafeRandomImpl.instance); + } + + @VisibleForTesting + RingHashLoadBalancer(Helper helper, ThreadSafeRandom random) { super(helper); syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); logger = XdsLogger.withLogId(InternalLogId.allocate("ring_hash_lb", helper.getAuthority())); logger.log(XdsLogLevel.INFO, "Created"); + this.random = checkNotNull(random, "random"); } @Override @@ -87,62 +100,50 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { return addressValidityStatus; } - try { - resolvingAddresses = true; - AcceptResolvedAddrRetVal acceptRetVal = acceptResolvedAddressesInternal(resolvedAddresses); - if (!acceptRetVal.status.isOk()) { - return acceptRetVal.status; - } - - // Now do the ringhash specific logic with weights and building the ring - RingHashConfig config = (RingHashConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - if (config == null) { - throw new IllegalArgumentException("Missing RingHash configuration"); + // Now do the ringhash specific logic with weights and building the ring + RingHashConfig config = (RingHashConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + if (config == null) { + throw new IllegalArgumentException("Missing RingHash configuration"); + } + requestHashHeaderKey = + config.requestHashHeader.isEmpty() + ? 
null + : Metadata.Key.of(config.requestHashHeader, Metadata.ASCII_STRING_MARSHALLER); + Map serverWeights = new HashMap<>(); + long totalWeight = 0L; + for (EquivalentAddressGroup eag : addrList) { + Long weight = eag.getAttributes().get(XdsAttributes.ATTR_SERVER_WEIGHT); + // Support two ways of server weighing: either multiple instances of the same address + // or each address contains a per-address weight attribute. If a weight is not provided, + // each occurrence of the address will be counted a weight value of one. + if (weight == null) { + weight = 1L; } - Map serverWeights = new HashMap<>(); - long totalWeight = 0L; - for (EquivalentAddressGroup eag : addrList) { - Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT); - // Support two ways of server weighing: either multiple instances of the same address - // or each address contains a per-address weight attribute. If a weight is not provided, - // each occurrence of the address will be counted a weight value of one. - if (weight == null) { - weight = 1L; - } - totalWeight += weight; - EquivalentAddressGroup addrKey = stripAttrs(eag); - if (serverWeights.containsKey(addrKey)) { - serverWeights.put(addrKey, serverWeights.get(addrKey) + weight); - } else { - serverWeights.put(addrKey, weight); - } + totalWeight += weight; + EquivalentAddressGroup addrKey = stripAttrs(eag); + if (serverWeights.containsKey(addrKey)) { + serverWeights.put(addrKey, serverWeights.get(addrKey) + weight); + } else { + serverWeights.put(addrKey, weight); } - // Calculate scale - long minWeight = Collections.min(serverWeights.values()); - double normalizedMinWeight = (double) minWeight / totalWeight; - // Scale up the number of hashes per host such that the least-weighted host gets a whole - // number of hashes on the the ring. Other hosts might not end up with whole numbers, and - // that's fine (the ring-building algorithm can handle this). 
This preserves the original - // implementation's behavior: when weights aren't provided, all hosts should get an equal - // number of hashes. In the case where this number exceeds the max_ring_size, it's scaled - // back down to fit. - double scale = Math.min( - Math.ceil(normalizedMinWeight * config.minRingSize) / normalizedMinWeight, - (double) config.maxRingSize); - - // Build the ring - ring = buildRing(serverWeights, totalWeight, scale); - - // Must update channel picker before return so that new RPCs will not be routed to deleted - // clusters and resolver can remove them in service config. - updateOverallBalancingState(); - - shutdownRemoved(acceptRetVal.removedChildren); - } finally { - this.resolvingAddresses = false; } - - return Status.OK; + // Calculate scale + long minWeight = Collections.min(serverWeights.values()); + double normalizedMinWeight = (double) minWeight / totalWeight; + // Scale up the number of hashes per host such that the least-weighted host gets a whole + // number of hashes on the the ring. Other hosts might not end up with whole numbers, and + // that's fine (the ring-building algorithm can handle this). This preserves the original + // implementation's behavior: when weights aren't provided, all hosts should get an equal + // number of hashes. In the case where this number exceeds the max_ring_size, it's scaled + // back down to fit. 
+ double scale = Math.min( + Math.ceil(normalizedMinWeight * config.minRingSize) / normalizedMinWeight, + (double) config.maxRingSize); + + // Build the ring + ring = buildRing(serverWeights, totalWeight, scale); + + return super.acceptResolvedAddresses(resolvedAddresses); } @@ -213,11 +214,32 @@ protected void updateOverallBalancingState() { overallState = TRANSIENT_FAILURE; } - RingHashPicker picker = new RingHashPicker(syncContext, ring, getChildLbStates()); + // gRFC A61: if the aggregated connectivity state is TRANSIENT_FAILURE or CONNECTING and + // there are no endpoints in CONNECTING state, the ring_hash policy will choose one of + // the endpoints in IDLE state (if any) to trigger a connection attempt on + if (numReady == 0 && numTF > 0 && numConnecting == 0 && numIdle > 0) { + triggerIdleChildConnection(); + } + + RingHashPicker picker = + new RingHashPicker(syncContext, ring, getChildLbStates(), requestHashHeaderKey, random); getHelper().updateBalancingState(overallState, picker); this.currentConnectivityState = overallState; } + + /** + * Triggers a connection attempt for the first IDLE child load balancer. 
+ */ + private void triggerIdleChildConnection() { + for (ChildLbState child : getChildLbStates()) { + if (child.getCurrentState() == ConnectivityState.IDLE) { + child.getLb().requestConnection(); + return; + } + } + } + @Override protected ChildLbState createChildLbState(Object key) { return new ChildLbState(key, lazyLbFactory); @@ -241,7 +263,7 @@ private Status validateAddrList(List addrList) { long totalWeight = 0; for (EquivalentAddressGroup eag : addrList) { - Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT); + Long weight = eag.getAttributes().get(XdsAttributes.ATTR_SERVER_WEIGHT); if (weight == null) { weight = 1L; @@ -341,21 +363,32 @@ private static final class RingHashPicker extends SubchannelPicker { // TODO(chengyuanzhang): can be more performance-friendly with // IdentityHashMap and RingEntry contains Subchannel. private final Map pickableSubchannels; // read-only + @Nullable private final Metadata.Key requestHashHeaderKey; + private final ThreadSafeRandom random; + private final boolean hasEndpointInConnectingState; private RingHashPicker( SynchronizationContext syncContext, List ring, - Collection children) { + Collection children, Metadata.Key requestHashHeaderKey, + ThreadSafeRandom random) { this.syncContext = syncContext; this.ring = ring; + this.requestHashHeaderKey = requestHashHeaderKey; + this.random = random; pickableSubchannels = new HashMap<>(children.size()); + boolean hasConnectingState = false; for (ChildLbState childLbState : children) { pickableSubchannels.put((Endpoint)childLbState.getKey(), new SubchannelView(childLbState, childLbState.getCurrentState())); + if (childLbState.getCurrentState() == CONNECTING) { + hasConnectingState = true; + } } + this.hasEndpointInConnectingState = hasConnectingState; } // Find the ring entry with hash next to (clockwise) the RPC's hash (binary search). 
- private int getTargetIndex(Long requestHash) { + private int getTargetIndex(long requestHash) { if (ring.size() <= 1) { return 0; } @@ -381,38 +414,80 @@ private int getTargetIndex(Long requestHash) { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - Long requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY); - if (requestHash == null) { - return PickResult.withError(RPC_HASH_NOT_FOUND); + // Determine request hash. + boolean usingRandomHash = false; + long requestHash; + if (requestHashHeaderKey == null) { + // Set by the xDS config selector. + Long rpcHashFromCallOptions = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY); + if (rpcHashFromCallOptions == null) { + return PickResult.withError(RPC_HASH_NOT_FOUND); + } + requestHash = rpcHashFromCallOptions; + } else { + Iterable headerValues = args.getHeaders().getAll(requestHashHeaderKey); + if (headerValues != null) { + requestHash = hashFunc.hashAsciiString(Joiner.on(",").join(headerValues)); + } else { + requestHash = random.nextLong(); + usingRandomHash = true; + } } int targetIndex = getTargetIndex(requestHash); - // Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, we ignore - // all TF subchannels and find the first ring entry in READY, CONNECTING or IDLE. If - // CONNECTING or IDLE we return a pick with no results. Additionally, if that entry is in - // IDLE, we initiate a connection. - for (int i = 0; i < ring.size(); i++) { - int index = (targetIndex + i) % ring.size(); - SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); - ChildLbState childLbState = subchannelView.childLbState; - - if (subchannelView.connectivityState == READY) { - return childLbState.getCurrentPicker().pickSubchannel(args); + if (!usingRandomHash) { + // Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, we ignore + // all TF subchannels and find the first ring entry in READY, CONNECTING or IDLE. 
If + // CONNECTING or IDLE we return a pick with no results. Additionally, if that entry is in + // IDLE, we initiate a connection. + for (int i = 0; i < ring.size(); i++) { + int index = (targetIndex + i) % ring.size(); + SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); + ChildLbState childLbState = subchannelView.childLbState; + + if (subchannelView.connectivityState == READY) { + return childLbState.getCurrentPicker().pickSubchannel(args); + } + + // RPCs can be buffered if the next subchannel is pending (per A62). Otherwise, RPCs + // are failed unless there is a READY connection. + if (subchannelView.connectivityState == CONNECTING) { + return PickResult.withNoResult(); + } + + if (subchannelView.connectivityState == IDLE) { + syncContext.execute(() -> { + if (childLbState.getCurrentState() == IDLE) { + childLbState.getLb().requestConnection(); + } + }); + + return PickResult.withNoResult(); // Indicates that this should be retried after backoff + } } - - // RPCs can be buffered if the next subchannel is pending (per A62). Otherwise, RPCs - // are failed unless there is a READY connection. - if (subchannelView.connectivityState == CONNECTING) { - return PickResult.withNoResult(); + } else { + // Using a random hash. Find and use the first READY ring entry, triggering at most one + // entry to attempt connection. 
+ boolean requestedConnection = hasEndpointInConnectingState; + for (int i = 0; i < ring.size(); i++) { + int index = (targetIndex + i) % ring.size(); + SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); + ChildLbState childLbState = subchannelView.childLbState; + if (subchannelView.connectivityState == READY) { + return childLbState.getCurrentPicker().pickSubchannel(args); + } + if (!requestedConnection && subchannelView.connectivityState == IDLE) { + syncContext.execute(() -> { + if (childLbState.getCurrentState() == IDLE) { + childLbState.getLb().requestConnection(); + } + }); + requestedConnection = true; + } } - - if (subchannelView.connectivityState == IDLE) { - syncContext.execute(() -> { - childLbState.getLb().requestConnection(); - }); - - return PickResult.withNoResult(); // Indicates that this should be retried after backoff + if (requestedConnection) { + return PickResult.withNoResult(); } } @@ -460,13 +535,32 @@ public int compareTo(RingEntry entry) { static final class RingHashConfig { final long minRingSize; final long maxRingSize; + final String requestHashHeader; - RingHashConfig(long minRingSize, long maxRingSize) { + RingHashConfig(long minRingSize, long maxRingSize, String requestHashHeader) { checkArgument(minRingSize > 0, "minRingSize <= 0"); checkArgument(maxRingSize > 0, "maxRingSize <= 0"); checkArgument(minRingSize <= maxRingSize, "minRingSize > maxRingSize"); + checkNotNull(requestHashHeader); this.minRingSize = minRingSize; this.maxRingSize = maxRingSize; + this.requestHashHeader = requestHashHeader; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof RingHashConfig)) { + return false; + } + RingHashConfig that = (RingHashConfig) o; + return this.minRingSize == that.minRingSize + && this.maxRingSize == that.maxRingSize + && Objects.equals(this.requestHashHeader, that.requestHashHeader); + } + + @Override + public int hashCode() { + return Objects.hash(minRingSize, maxRingSize, 
requestHashHeader); } @Override @@ -474,6 +568,7 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("minRingSize", minRingSize) .add("maxRingSize", maxRingSize) + .add("requestHashHeader", requestHashHeader) .toString(); } } diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java index dad79384569..bb4f8de5a5f 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java @@ -24,6 +24,7 @@ import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import io.grpc.xds.RingHashOptions; @@ -81,6 +82,10 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( Map rawLoadBalancingPolicyConfig) { Long minRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "minRingSize"); Long maxRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "maxRingSize"); + String requestHashHeader = ""; + if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY", false)) { + requestHashHeader = JsonUtil.getString(rawLoadBalancingPolicyConfig, "requestHashHeader"); + } long maxRingSizeCap = RingHashOptions.getRingSizeCap(); if (minRingSize == null) { minRingSize = DEFAULT_MIN_RING_SIZE; @@ -88,6 +93,9 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( if (maxRingSize == null) { maxRingSize = DEFAULT_MAX_RING_SIZE; } + if (requestHashHeader == null) { + requestHashHeader = ""; + } if (minRingSize > maxRingSizeCap) { minRingSize = maxRingSizeCap; } @@ -96,8 +104,9 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( } if (minRingSize <= 0 || maxRingSize <= 0 || minRingSize > maxRingSize) { return ConfigOrError.fromError(Status.UNAVAILABLE.withDescription( - "Invalid 
'mingRingSize'/'maxRingSize'")); + "Invalid 'minRingSize'/'maxRingSize'")); } - return ConfigOrError.fromConfig(new RingHashConfig(minRingSize, maxRingSize)); + return ConfigOrError.fromConfig( + new RingHashConfig(minRingSize, maxRingSize, requestHashHeader)); } } diff --git a/xds/src/main/java/io/grpc/xds/RouterFilter.java b/xds/src/main/java/io/grpc/xds/RouterFilter.java index 7f1adf86a6d..504c4213149 100644 --- a/xds/src/main/java/io/grpc/xds/RouterFilter.java +++ b/xds/src/main/java/io/grpc/xds/RouterFilter.java @@ -17,19 +17,12 @@ package io.grpc.xds; import com.google.protobuf.Message; -import io.grpc.ClientInterceptor; -import io.grpc.LoadBalancer.PickSubchannelArgs; -import io.grpc.ServerInterceptor; -import io.grpc.xds.Filter.ClientInterceptorBuilder; -import io.grpc.xds.Filter.ServerInterceptorBuilder; -import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; /** * Router filter implementation. Currently this filter does not parse any field in the config. 
*/ -enum RouterFilter implements Filter, ClientInterceptorBuilder, ServerInterceptorBuilder { - INSTANCE; +final class RouterFilter implements Filter { + private static final RouterFilter INSTANCE = new RouterFilter(); static final String TYPE_URL = "type.googleapis.com/envoy.extensions.filters.http.router.v3.Router"; @@ -37,7 +30,7 @@ enum RouterFilter implements Filter, ClientInterceptorBuilder, ServerInterceptor static final FilterConfig ROUTER_CONFIG = new FilterConfig() { @Override public String typeUrl() { - return RouterFilter.TYPE_URL; + return TYPE_URL; } @Override @@ -46,33 +39,38 @@ public String toString() { } }; - @Override - public String[] typeUrls() { - return new String[] { TYPE_URL }; - } + static final class Provider implements Filter.Provider { + @Override + public String[] typeUrls() { + return new String[]{TYPE_URL}; + } - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - return ConfigOrError.fromConfig(ROUTER_CONFIG); - } + @Override + public boolean isClientFilter() { + return true; + } - @Override - public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { - return ConfigOrError.fromError("Router Filter should not have override config"); - } + @Override + public boolean isServerFilter() { + return true; + } - @Nullable - @Override - public ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, - ScheduledExecutorService scheduler) { - return null; - } + @Override + public RouterFilter newInstance(String name) { + return INSTANCE; + } - @Nullable - @Override - public ServerInterceptor buildServerInterceptor( - FilterConfig config, @Nullable Filter.FilterConfig overrideConfig) { - return null; + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + return ConfigOrError.fromConfig(ROUTER_CONFIG); + } + + @Override + public ConfigOrError parseFilterConfigOverride( + Message rawProtoMessage) { + return 
ConfigOrError.fromError("Router Filter should not have override config"); + } } + + private RouterFilter() {} } diff --git a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java index c9195896d82..45c379244af 100644 --- a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java @@ -17,10 +17,12 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.CallCredentials; +import io.grpc.MetricRecorder; import io.grpc.internal.ExponentialBackoffPolicy; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; @@ -35,11 +37,9 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -51,54 +51,65 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory { private static final boolean LOG_XDS_NODE_ID = Boolean.parseBoolean( System.getenv("GRPC_LOG_XDS_NODE_ID")); private static final Logger log = Logger.getLogger(XdsClientImpl.class.getName()); + private static final ExponentialBackoffPolicy.Provider BACKOFF_POLICY_PROVIDER = + new ExponentialBackoffPolicy.Provider(); + @Nullable private final Bootstrapper bootstrapper; private final Object lock = new Object(); - private final AtomicReference> bootstrapOverride = new AtomicReference<>(); private final Map> targetToXdsClientMap = new 
ConcurrentHashMap<>(); SharedXdsClientPoolProvider() { - this(new GrpcBootstrapperImpl()); + this(null); } @VisibleForTesting - SharedXdsClientPoolProvider(Bootstrapper bootstrapper) { - this.bootstrapper = checkNotNull(bootstrapper, "bootstrapper"); + SharedXdsClientPoolProvider(@Nullable Bootstrapper bootstrapper) { + this.bootstrapper = bootstrapper; } static SharedXdsClientPoolProvider getDefaultProvider() { return SharedXdsClientPoolProviderHolder.instance; } - @Override - public void setBootstrapOverride(Map bootstrap) { - bootstrapOverride.set(bootstrap); - } - @Override @Nullable public ObjectPool get(String target) { return targetToXdsClientMap.get(target); } + @Deprecated + public ObjectPool getOrCreate( + String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials) + throws XdsInitializationException { + BootstrapInfo bootstrapInfo; + if (bootstrapper != null) { + bootstrapInfo = bootstrapper.bootstrap(); + } else { + bootstrapInfo = GrpcBootstrapperImpl.defaultBootstrap(); + } + return getOrCreate(target, bootstrapInfo, metricRecorder, transportCallCredentials); + } + @Override - public ObjectPool getOrCreate(String target) throws XdsInitializationException { + public ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder) { + return getOrCreate(target, bootstrapInfo, metricRecorder, null); + } + + public ObjectPool getOrCreate( + String target, + BootstrapInfo bootstrapInfo, + MetricRecorder metricRecorder, + CallCredentials transportCallCredentials) { ObjectPool ref = targetToXdsClientMap.get(target); if (ref == null) { synchronized (lock) { ref = targetToXdsClientMap.get(target); if (ref == null) { - BootstrapInfo bootstrapInfo; - Map rawBootstrap = bootstrapOverride.get(); - if (rawBootstrap != null) { - bootstrapInfo = bootstrapper.bootstrap(rawBootstrap); - } else { - bootstrapInfo = bootstrapper.bootstrap(); - } - if (bootstrapInfo.servers().isEmpty()) { - throw new 
XdsInitializationException("No xDS server provided"); - } - ref = new RefCountedXdsClientObjectPool(bootstrapInfo, target); + ref = + new RefCountedXdsClientObjectPool( + bootstrapInfo, target, metricRecorder, transportCallCredentials); targetToXdsClientMap.put(target, ref); } } @@ -111,19 +122,18 @@ public ImmutableList getTargets() { return ImmutableList.copyOf(targetToXdsClientMap.keySet()); } - private static class SharedXdsClientPoolProviderHolder { private static final SharedXdsClientPoolProvider instance = new SharedXdsClientPoolProvider(); } @ThreadSafe @VisibleForTesting - static class RefCountedXdsClientObjectPool implements ObjectPool { + class RefCountedXdsClientObjectPool implements ObjectPool { - private static final ExponentialBackoffPolicy.Provider BACKOFF_POLICY_PROVIDER = - new ExponentialBackoffPolicy.Provider(); private final BootstrapInfo bootstrapInfo; private final String target; // The target associated with the xDS client. + private final MetricRecorder metricRecorder; + private final CallCredentials transportCallCredentials; private final Object lock = new Object(); @GuardedBy("lock") private ScheduledExecutorService scheduler; @@ -131,11 +141,25 @@ static class RefCountedXdsClientObjectPool implements ObjectPool { private XdsClient xdsClient; @GuardedBy("lock") private int refCount; + @GuardedBy("lock") + private XdsClientMetricReporterImpl metricReporter; + + @VisibleForTesting + RefCountedXdsClientObjectPool( + BootstrapInfo bootstrapInfo, String target, MetricRecorder metricRecorder) { + this(bootstrapInfo, target, metricRecorder, null); + } @VisibleForTesting - RefCountedXdsClientObjectPool(BootstrapInfo bootstrapInfo, String target) { - this.bootstrapInfo = checkNotNull(bootstrapInfo); + RefCountedXdsClientObjectPool( + BootstrapInfo bootstrapInfo, + String target, + MetricRecorder metricRecorder, + CallCredentials transportCallCredentials) { + this.bootstrapInfo = checkNotNull(bootstrapInfo, "bootstrapInfo"); this.target = target; + 
this.metricRecorder = checkNotNull(metricRecorder, "metricRecorder"); + this.transportCallCredentials = transportCallCredentials; } @Override @@ -146,15 +170,21 @@ public XdsClient getObject() { log.log(Level.INFO, "xDS node ID: {0}", bootstrapInfo.node().getId()); } scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE); - xdsClient = new XdsClientImpl( - DEFAULT_XDS_TRANSPORT_FACTORY, - bootstrapInfo, - scheduler, - BACKOFF_POLICY_PROVIDER, - GrpcUtil.STOPWATCH_SUPPLIER, - TimeProvider.SYSTEM_TIME_PROVIDER, - MessagePrinter.INSTANCE, - new TlsContextManagerImpl(bootstrapInfo)); + metricReporter = new XdsClientMetricReporterImpl(metricRecorder, target); + GrpcXdsTransportFactory xdsTransportFactory = + new GrpcXdsTransportFactory(transportCallCredentials); + xdsClient = + new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + scheduler, + BACKOFF_POLICY_PROVIDER, + GrpcUtil.STOPWATCH_SUPPLIER, + TimeProvider.SYSTEM_TIME_PROVIDER, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + metricReporter); + metricReporter.setXdsClient(xdsClient); } refCount++; return xdsClient; @@ -168,7 +198,14 @@ public XdsClient returnObject(Object object) { if (refCount == 0) { xdsClient.shutdown(); xdsClient = null; + metricReporter.close(); + metricReporter = null; + targetToXdsClientMap.remove(target); scheduler = SharedResourceHolder.release(GrpcUtil.TIMER_SERVICE, scheduler); + } else if (refCount < 0) { + assert false; // We want our tests to fail + log.log(Level.SEVERE, "Negative reference count. 
File a bug", new Exception()); + refCount = 0; } return null; } diff --git a/xds/src/main/java/io/grpc/xds/StructOrError.java b/xds/src/main/java/io/grpc/xds/StructOrError.java new file mode 100644 index 00000000000..14f008d191e --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/StructOrError.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import javax.annotation.Nullable; + +/** An object or a String error. */ +final class StructOrError { + + /** + * Returns a {@link StructOrError} for the successfully converted data object. + */ + public static StructOrError fromStruct(T struct) { + return new StructOrError<>(struct); + } + + /** + * Returns a {@link StructOrError} for the failure to convert the data object. + */ + public static StructOrError fromError(String errorDetail) { + return new StructOrError<>(errorDetail); + } + + private final String errorDetail; + private final T struct; + + private StructOrError(T struct) { + this.struct = checkNotNull(struct, "struct"); + this.errorDetail = null; + } + + private StructOrError(String errorDetail) { + this.struct = null; + this.errorDetail = checkNotNull(errorDetail, "errorDetail"); + } + + /** + * Returns struct if exists, otherwise null. 
+ */ + @VisibleForTesting + @Nullable + public T getStruct() { + return struct; + } + + /** + * Returns error detail if exists, otherwise null. + */ + @VisibleForTesting + @Nullable + public String getErrorDetail() { + return errorDetail; + } +} + diff --git a/xds/src/main/java/io/grpc/xds/VirtualHost.java b/xds/src/main/java/io/grpc/xds/VirtualHost.java index d9f93dd3a07..5cc979984c6 100644 --- a/xds/src/main/java/io/grpc/xds/VirtualHost.java +++ b/xds/src/main/java/io/grpc/xds/VirtualHost.java @@ -166,29 +166,34 @@ abstract static class RouteAction { @Nullable abstract RetryPolicy retryPolicy(); + abstract boolean autoHostRewrite(); + static RouteAction forCluster( String cluster, List hashPolicies, @Nullable Long timeoutNano, - @Nullable RetryPolicy retryPolicy) { + @Nullable RetryPolicy retryPolicy, boolean autoHostRewrite) { checkNotNull(cluster, "cluster"); - return RouteAction.create(hashPolicies, timeoutNano, cluster, null, null, retryPolicy); + return RouteAction.create(hashPolicies, timeoutNano, cluster, null, null, retryPolicy, + autoHostRewrite); } static RouteAction forWeightedClusters( List weightedClusters, List hashPolicies, - @Nullable Long timeoutNano, @Nullable RetryPolicy retryPolicy) { + @Nullable Long timeoutNano, @Nullable RetryPolicy retryPolicy, boolean autoHostRewrite) { checkNotNull(weightedClusters, "weightedClusters"); checkArgument(!weightedClusters.isEmpty(), "empty cluster list"); return RouteAction.create( - hashPolicies, timeoutNano, null, weightedClusters, null, retryPolicy); + hashPolicies, timeoutNano, null, weightedClusters, null, retryPolicy, autoHostRewrite); } static RouteAction forClusterSpecifierPlugin( NamedPluginConfig namedConfig, List hashPolicies, @Nullable Long timeoutNano, - @Nullable RetryPolicy retryPolicy) { + @Nullable RetryPolicy retryPolicy, + boolean autoHostRewrite) { checkNotNull(namedConfig, "namedConfig"); - return RouteAction.create(hashPolicies, timeoutNano, null, null, namedConfig, retryPolicy); + 
return RouteAction.create(hashPolicies, timeoutNano, null, null, namedConfig, retryPolicy, + autoHostRewrite); } private static RouteAction create( @@ -197,26 +202,29 @@ private static RouteAction create( @Nullable String cluster, @Nullable List weightedClusters, @Nullable NamedPluginConfig namedConfig, - @Nullable RetryPolicy retryPolicy) { + @Nullable RetryPolicy retryPolicy, + boolean autoHostRewrite) { return new AutoValue_VirtualHost_Route_RouteAction( ImmutableList.copyOf(hashPolicies), timeoutNano, cluster, weightedClusters == null ? null : ImmutableList.copyOf(weightedClusters), namedConfig, - retryPolicy); + retryPolicy, + autoHostRewrite); } @AutoValue abstract static class ClusterWeight { abstract String name(); - abstract int weight(); + abstract long weight(); abstract ImmutableMap filterConfigOverrides(); static ClusterWeight create( - String name, int weight, Map filterConfigOverrides) { + String name, long weight, Map filterConfigOverrides) { + checkArgument(weight >= 0, "weight must not be negative"); return new AutoValue_VirtualHost_Route_RouteAction_ClusterWeight( name, weight, ImmutableMap.copyOf(filterConfigOverrides)); } diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index 73764c63c80..a8b7e120cca 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -40,6 +40,7 @@ import io.grpc.services.MetricReport; import io.grpc.util.ForwardingSubchannel; import io.grpc.util.MultiChildLoadBalancer; +import io.grpc.xds.internal.MetricReportUtils; import io.grpc.xds.orca.OrcaOobUtil; import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener; import io.grpc.xds.orca.OrcaPerRequestUtil; @@ -48,6 +49,8 @@ import java.util.Collection; import java.util.HashSet; import java.util.List; +import java.util.Objects; +import java.util.OptionalDouble; import 
java.util.Random; import java.util.Set; import java.util.concurrent.ScheduledExecutorService; @@ -102,32 +105,44 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer { private final long infTime; private final Ticker ticker; private String locality = ""; + private String backendService = ""; private SubchannelPicker currentPicker = new FixedResultPicker(PickResult.withNoResult()); // The metric instruments are only registered once and shared by all instances of this LB. static { MetricInstrumentRegistry metricInstrumentRegistry = MetricInstrumentRegistry.getDefaultRegistry(); - RR_FALLBACK_COUNTER = metricInstrumentRegistry.registerLongCounter("grpc.lb.wrr.rr_fallback", + RR_FALLBACK_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.lb.wrr.rr_fallback", "EXPERIMENTAL. Number of scheduler updates in which there were not enough endpoints " + "with valid weight, which caused the WRR policy to fall back to RR behavior", - "{update}", Lists.newArrayList("grpc.target"), Lists.newArrayList("grpc.lb.locality"), + "{update}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.locality", "grpc.lb.backend_service"), false); ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER = metricInstrumentRegistry.registerLongCounter( - "grpc.lb.wrr.endpoint_weight_not_yet_usable", "EXPERIMENTAL. Number of endpoints " - + "from each scheduler update that don't yet have usable weight information", - "{endpoint}", Lists.newArrayList("grpc.target"), Lists.newArrayList("grpc.lb.locality"), + "grpc.lb.wrr.endpoint_weight_not_yet_usable", + "EXPERIMENTAL. Number of endpoints from each scheduler update that don't yet have usable " + + "weight information", + "{endpoint}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.locality", "grpc.lb.backend_service"), false); ENDPOINT_WEIGHT_STALE_COUNTER = metricInstrumentRegistry.registerLongCounter( "grpc.lb.wrr.endpoint_weight_stale", "EXPERIMENTAL. 
Number of endpoints from each scheduler update whose latest weight is " - + "older than the expiration period", "{endpoint}", Lists.newArrayList("grpc.target"), - Lists.newArrayList("grpc.lb.locality"), false); + + "older than the expiration period", + "{endpoint}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.locality", "grpc.lb.backend_service"), + false); ENDPOINT_WEIGHTS_HISTOGRAM = metricInstrumentRegistry.registerDoubleHistogram( "grpc.lb.wrr.endpoint_weights", "EXPERIMENTAL. The histogram buckets will be endpoint weight ranges.", - "{weight}", Lists.newArrayList(), Lists.newArrayList("grpc.target"), - Lists.newArrayList("grpc.lb.locality"), + "{weight}", + Lists.newArrayList(), + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.locality", "grpc.lb.backend_service"), false); } @@ -168,33 +183,26 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { } else { this.locality = ""; } + String backendService + = resolvedAddresses.getAttributes().get(NameResolver.ATTR_BACKEND_SERVICE); + if (backendService != null) { + this.backendService = backendService; + } else { + this.backendService = ""; + } config = - (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - AcceptResolvedAddrRetVal acceptRetVal; - try { - resolvingAddresses = true; - acceptRetVal = acceptResolvedAddressesInternal(resolvedAddresses); - if (!acceptRetVal.status.isOk()) { - return acceptRetVal.status; - } + (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - if (weightUpdateTimer != null && weightUpdateTimer.isPending()) { - weightUpdateTimer.cancel(); - } - updateWeightTask.run(); + if (weightUpdateTimer != null && weightUpdateTimer.isPending()) { + weightUpdateTimer.cancel(); + } + updateWeightTask.run(); - createAndApplyOrcaListeners(); + Status status = super.acceptResolvedAddresses(resolvedAddresses); - // Must update channel picker before return so that 
new RPCs will not be routed to deleted - // clusters and resolver can remove them in service config. - updateOverallBalancingState(); + createAndApplyOrcaListeners(); - shutdownRemoved(acceptRetVal.removedChildren); - } finally { - resolvingAddresses = false; - } - - return acceptRetVal.status; + return status; } /** @@ -230,7 +238,8 @@ protected void updateOverallBalancingState() { private SubchannelPicker createReadyPicker(Collection activeList) { WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), - config.enableOobLoadReport, config.errorUtilizationPenalty, sequence); + config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, + config.metricNamesForComputingUtilization); updateWeight(picker); return picker; } @@ -246,7 +255,7 @@ private void updateWeight(WeightedRoundRobinPicker picker) { helper.getMetricRecorder() .recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight, ImmutableList.of(helper.getChannelTarget()), - ImmutableList.of(locality)); + ImmutableList.of(locality, backendService)); newWeights[i] = newWeight > 0 ? 
(float) newWeight : 0.0f; } @@ -254,18 +263,19 @@ private void updateWeight(WeightedRoundRobinPicker picker) { helper.getMetricRecorder() .addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(), ImmutableList.of(helper.getChannelTarget()), - ImmutableList.of(locality)); + ImmutableList.of(locality, backendService)); } if (notYetUsableEndpoints.get() > 0) { helper.getMetricRecorder() .addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(), - ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality)); + ImmutableList.of(helper.getChannelTarget()), + ImmutableList.of(locality, backendService)); } boolean weightsEffective = picker.updateWeight(newWeights); if (!weightsEffective) { helper.getMetricRecorder() .addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()), - ImmutableList.of(locality)); + ImmutableList.of(locality, backendService)); } } @@ -318,12 +328,16 @@ public void addSubchannel(WrrSubchannel wrrSubchannel) { subchannels.add(wrrSubchannel); } - public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty) { + public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty, + ImmutableList metricNamesForComputingUtilization) { if (orcaReportListener != null - && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty) { + && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty + && orcaReportListener.metricNamesForComputingUtilization + .equals(metricNamesForComputingUtilization)) { return orcaReportListener; } - orcaReportListener = new OrcaReportListener(errorUtilizationPenalty); + orcaReportListener = + new OrcaReportListener(errorUtilizationPenalty, metricNamesForComputingUtilization); return orcaReportListener; } @@ -348,18 +362,19 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener 
{ private final float errorUtilizationPenalty; + private final ImmutableList metricNamesForComputingUtilization; - OrcaReportListener(float errorUtilizationPenalty) { + OrcaReportListener(float errorUtilizationPenalty, + ImmutableList metricNamesForComputingUtilization) { this.errorUtilizationPenalty = errorUtilizationPenalty; + this.metricNamesForComputingUtilization = metricNamesForComputingUtilization; } @Override public void onLoadReport(MetricReport report) { + double utilization = getUtilization(report, metricNamesForComputingUtilization); + double newWeight = 0; - // Prefer application utilization and fallback to CPU utilization if unset. - double utilization = - report.getApplicationUtilization() > 0 ? report.getApplicationUtilization() - : report.getCpuUtilization(); if (utilization > 0 && report.getQps() > 0) { double penalty = 0; if (report.getEps() > 0 && errorUtilizationPenalty > 0) { @@ -376,6 +391,40 @@ public void onLoadReport(MetricReport report) { lastUpdated = ticker.nanoTime(); weight = newWeight; } + + /** + * Returns the utilization value computed from the specified metric names. If the custom + * metrics are present and valid, the maximum of the custom metrics is returned. Otherwise, + * if application utilization is > 0, it is returned. If neither are present, the CPU + * utilization is returned. + */ + private double getUtilization(MetricReport report, ImmutableList metricNames) { + OptionalDouble customUtil = getCustomMetricUtilization(report, metricNames); + if (customUtil.isPresent()) { + return customUtil.getAsDouble(); + } + double appUtil = report.getApplicationUtilization(); + if (appUtil > 0) { + return appUtil; + } + return report.getCpuUtilization(); + } + + /** + * Returns the maximum utilization value among the specified metric names. + * Returns OptionalDouble.empty() if NONE of the specified metrics are present in the report, + * or if all present metrics are NaN. 
+ * Returns OptionalDouble.of(maxUtil) if at least one non-NaN metric is present. + */ + private OptionalDouble getCustomMetricUtilization(MetricReport report, + ImmutableList metricNames) { + return metricNames.stream() + .map(name -> MetricReportUtils.getMetric(report, name)) + .filter(OptionalDouble::isPresent) + .mapToDouble(OptionalDouble::getAsDouble) + .filter(d -> !Double.isNaN(d) && d > 0) + .max(); + } } } @@ -396,10 +445,10 @@ private void createAndApplyOrcaListeners() { for (WrrSubchannel weightedSubchannel : wChild.subchannels) { if (config.enableOobLoadReport) { OrcaOobUtil.setListener(weightedSubchannel, - wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty), + wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty, + config.metricNamesForComputingUtilization), OrcaOobUtil.OrcaReportingConfig.newBuilder() - .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS) - .build()); + .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS).build()); } else { OrcaOobUtil.setListener(weightedSubchannel, null, null); } @@ -466,7 +515,8 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker { private volatile StaticStrideScheduler scheduler; WeightedRoundRobinPicker(List children, boolean enableOobLoadReport, - float errorUtilizationPenalty, AtomicInteger sequence) { + float errorUtilizationPenalty, AtomicInteger sequence, + ImmutableList metricNamesForComputingUtilization) { checkNotNull(children, "children"); Preconditions.checkArgument(!children.isEmpty(), "empty child list"); this.children = children; @@ -475,7 +525,8 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker { for (ChildLbState child : children) { WeightedChildLbState wChild = (WeightedChildLbState) child; pickers.add(wChild.getCurrentPicker()); - reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty)); + reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty, + 
metricNamesForComputingUtilization)); } this.pickers = pickers; this.reportListeners = reportListeners; @@ -501,12 +552,15 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { if (subchannel == null) { return pickResult; } + + subchannel = ((WrrSubchannel) subchannel).delegate(); if (!enableOobLoadReport) { - return PickResult.withSubchannel(subchannel, - OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( - reportListeners.get(pick))); + return pickResult.copyWithSubchannel(subchannel) + .copyWithStreamTracerFactory( + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( + reportListeners.get(pick))); } else { - return PickResult.withSubchannel(subchannel); + return pickResult.copyWithSubchannel(subchannel); } } @@ -713,32 +767,57 @@ static final class WeightedRoundRobinLoadBalancerConfig { final long oobReportingPeriodNanos; final long weightUpdatePeriodNanos; final float errorUtilizationPenalty; + final ImmutableList metricNamesForComputingUtilization; public static Builder newBuilder() { return new Builder(); } private WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos, - long weightExpirationPeriodNanos, - boolean enableOobLoadReport, - long oobReportingPeriodNanos, - long weightUpdatePeriodNanos, - float errorUtilizationPenalty) { + long weightExpirationPeriodNanos, boolean enableOobLoadReport, long oobReportingPeriodNanos, + long weightUpdatePeriodNanos, float errorUtilizationPenalty, + ImmutableList metricNamesForComputingUtilization) { this.blackoutPeriodNanos = blackoutPeriodNanos; this.weightExpirationPeriodNanos = weightExpirationPeriodNanos; this.enableOobLoadReport = enableOobLoadReport; this.oobReportingPeriodNanos = oobReportingPeriodNanos; this.weightUpdatePeriodNanos = weightUpdatePeriodNanos; this.errorUtilizationPenalty = errorUtilizationPenalty; + this.metricNamesForComputingUtilization = metricNamesForComputingUtilization; + } + + @Override + public boolean equals(Object o) { + if (!(o 
instanceof WeightedRoundRobinLoadBalancerConfig)) { + return false; + } + WeightedRoundRobinLoadBalancerConfig that = (WeightedRoundRobinLoadBalancerConfig) o; + return this.blackoutPeriodNanos == that.blackoutPeriodNanos + && this.weightExpirationPeriodNanos == that.weightExpirationPeriodNanos + && this.enableOobLoadReport == that.enableOobLoadReport + && this.oobReportingPeriodNanos == that.oobReportingPeriodNanos + && this.weightUpdatePeriodNanos == that.weightUpdatePeriodNanos + // Float.compare considers NaNs equal + && Float.compare(this.errorUtilizationPenalty, that.errorUtilizationPenalty) == 0 + && Objects.equals(this.metricNamesForComputingUtilization, + that.metricNamesForComputingUtilization); + } + + @Override + public int hashCode() { + return Objects.hash(blackoutPeriodNanos, weightExpirationPeriodNanos, enableOobLoadReport, + oobReportingPeriodNanos, weightUpdatePeriodNanos, errorUtilizationPenalty, + metricNamesForComputingUtilization); } static final class Builder { long blackoutPeriodNanos = 10_000_000_000L; // 10s - long weightExpirationPeriodNanos = 180_000_000_000L; //3min + long weightExpirationPeriodNanos = 180_000_000_000L; // 3min boolean enableOobLoadReport = false; long oobReportingPeriodNanos = 10_000_000_000L; // 10s long weightUpdatePeriodNanos = 1_000_000_000L; // 1s float errorUtilizationPenalty = 1.0F; + ImmutableList metricNamesForComputingUtilization = ImmutableList.of(); private Builder() { @@ -776,10 +855,17 @@ Builder setErrorUtilizationPenalty(float errorUtilizationPenalty) { return this; } + Builder setMetricNamesForComputingUtilization( + List metricNamesForComputingUtilization) { + this.metricNamesForComputingUtilization = + ImmutableList.copyOf(metricNamesForComputingUtilization); + return this; + } + WeightedRoundRobinLoadBalancerConfig build() { return new WeightedRoundRobinLoadBalancerConfig(blackoutPeriodNanos, - weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos, - weightUpdatePeriodNanos, 
errorUtilizationPenalty); + weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos, + weightUpdatePeriodNanos, errorUtilizationPenalty, metricNamesForComputingUtilization); } } } diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java index 433ea34b857..e17b8764a6c 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java @@ -24,8 +24,10 @@ import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; +import java.util.List; import java.util.Map; /** @@ -73,14 +75,16 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map rawConfig) { Long blackoutPeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "blackoutPeriod"); Long weightExpirationPeriodNanos = - JsonUtil.getStringAsDuration(rawConfig, "weightExpirationPeriod"); + JsonUtil.getStringAsDuration(rawConfig, "weightExpirationPeriod"); Long oobReportingPeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "oobReportingPeriod"); Boolean enableOobLoadReport = JsonUtil.getBoolean(rawConfig, "enableOobLoadReport"); Long weightUpdatePeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "weightUpdatePeriod"); Float errorUtilizationPenalty = JsonUtil.getNumberAsFloat(rawConfig, "errorUtilizationPenalty"); + List metricNamesForComputingUtilization = JsonUtil.getListOfStrings(rawConfig, + LoadBalancerConfigFactory.METRIC_NAMES_FOR_COMPUTING_UTILIZATION); WeightedRoundRobinLoadBalancerConfig.Builder configBuilder = - WeightedRoundRobinLoadBalancerConfig.newBuilder(); + 
WeightedRoundRobinLoadBalancerConfig.newBuilder(); if (blackoutPeriodNanos != null) { configBuilder.setBlackoutPeriodNanos(blackoutPeriodNanos); } @@ -102,6 +106,10 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map rawC if (errorUtilizationPenalty != null) { configBuilder.setErrorUtilizationPenalty(errorUtilizationPenalty); } + if (metricNamesForComputingUtilization != null + && GrpcUtil.getFlag("GRPC_EXPERIMENTAL_WRR_CUSTOM_METRICS", false)) { + configBuilder.setMetricNamesForComputingUtilization(metricNamesForComputingUtilization); + } return ConfigOrError.fromConfig(configBuilder.build()); } } diff --git a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java index 0a11f118057..9468a9daf9d 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java @@ -87,8 +87,9 @@ public Status acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresse } } targets = newTargets; + Status status = Status.OK; for (String targetName : targets.keySet()) { - childBalancers.get(targetName).handleResolvedAddresses( + Status newStatus = childBalancers.get(targetName).acceptResolvedAddresses( resolvedAddresses.toBuilder() .setAddresses(AddressFilter.filter(resolvedAddresses.getAddresses(), targetName)) .setLoadBalancingPolicyConfig(targets.get(targetName).childConfig) @@ -96,6 +97,9 @@ public Status acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresse .set(CHILD_NAME, targetName) .build()) .build()); + if (!newStatus.isOk()) { + status = newStatus; + } } // Cleanup removed targets. 
@@ -108,7 +112,7 @@ public Status acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresse childBalancers.keySet().retainAll(targets.keySet()); childHelpers.keySet().retainAll(targets.keySet()); updateOverallBalancingState(); - return Status.OK; + return status; } @Override @@ -124,6 +128,8 @@ public void handleNameResolutionError(Status error) { } @Override + @Deprecated + @SuppressWarnings("InlineMeSuggester") public boolean canHandleEmptyAddressListFromNameResolution() { return true; } diff --git a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java index 55f33fb11aa..15318693aca 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java @@ -25,6 +25,7 @@ import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; import io.grpc.util.GracefulSwitchLoadBalancer; import java.util.LinkedHashMap; @@ -99,9 +100,10 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { ConfigOrError childConfig = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( JsonUtil.getListOfObjects(rawWeightedTarget, "childPolicy"), lbRegistry); if (childConfig.getError() != null) { - return ConfigOrError.fromError(Status.INTERNAL - .withDescription("Could not parse weighted_target's child policy:" + name) - .withCause(childConfig.getError().asRuntimeException())); + return ConfigOrError.fromError(GrpcUtil.statusWithDetails( + Status.Code.INTERNAL, + "Could not parse weighted_target's child policy: " + name, + childConfig.getError())); } parsedChildConfigs.put(name, new WeightedPolicySelection(weight, childConfig.getConfig())); } diff --git a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java 
index 46d2443d36a..1a12412f923 100644 --- a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java @@ -74,8 +74,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { Map localityWeights = new HashMap<>(); for (EquivalentAddressGroup eag : resolvedAddresses.getAddresses()) { Attributes eagAttrs = eag.getAttributes(); - String locality = eagAttrs.get(InternalXdsAttributes.ATTR_LOCALITY_NAME); - Integer localityWeight = eagAttrs.get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT); + String locality = eagAttrs.get(EquivalentAddressGroup.ATTR_LOCALITY_NAME); + Integer localityWeight = eagAttrs.get(XdsAttributes.ATTR_LOCALITY_WEIGHT); if (locality == null) { Status unavailableStatus = Status.UNAVAILABLE.withDescription( @@ -113,12 +113,10 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { Object switchConfig = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( lbRegistry.getProvider(WEIGHTED_TARGET_POLICY_NAME), new WeightedTargetConfig(weightedPolicySelections)); - switchLb.handleResolvedAddresses( + return switchLb.acceptResolvedAddresses( resolvedAddresses.toBuilder() .setLoadBalancingPolicyConfig(switchConfig) .build()); - - return Status.OK; } @Override diff --git a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancerProvider.java index 384831b8a05..3693df9208a 100644 --- a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancerProvider.java @@ -23,6 +23,7 @@ import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.WrrLocalityLoadBalancer.WrrLocalityConfig; @@ -62,9 +63,10 @@ public ConfigOrError 
parseLoadBalancingPolicyConfig(Map rawConfig) { ConfigOrError childConfig = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( JsonUtil.getListOfObjects(rawConfig, "childPolicy")); if (childConfig.getError() != null) { - return ConfigOrError.fromError(Status.INTERNAL - .withDescription("Failed to parse child policy in wrr_locality LB policy: " + rawConfig) - .withCause(childConfig.getError().asRuntimeException())); + return ConfigOrError.fromError(GrpcUtil.statusWithDetails( + Status.Code.INTERNAL, + "Failed to parse child policy in wrr_locality LB policy", + childConfig.getError())); } return ConfigOrError.fromConfig(new WrrLocalityConfig(childConfig.getConfig())); } catch (RuntimeException e) { diff --git a/xds/src/main/java/io/grpc/xds/XdsAttributes.java b/xds/src/main/java/io/grpc/xds/XdsAttributes.java new file mode 100644 index 00000000000..d3fe8d4619c --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsAttributes.java @@ -0,0 +1,104 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; +import io.grpc.Grpc; +import io.grpc.InternalEquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; +import io.grpc.xds.client.Locality; +import io.grpc.xds.client.XdsClient; + +/** + * Attributes used for xDS implementation. 
+ */ +final class XdsAttributes { + /** + * Attribute key for passing around the XdsClient object pool across NameResolver/LoadBalancers. + */ + @NameResolver.ResolutionResultAttr + static final Attributes.Key XDS_CLIENT = + Attributes.Key.create("io.grpc.xds.XdsAttributes.xdsClient"); + + /** + * Attribute key for passing around the latest XdsConfig across NameResolver/LoadBalancers. + */ + @NameResolver.ResolutionResultAttr + static final Attributes.Key XDS_CONFIG = + Attributes.Key.create("io.grpc.xds.XdsAttributes.xdsConfig"); + + + /** + * Attribute key for passing around the XdsDependencyManager across NameResolver/LoadBalancers. + */ + @NameResolver.ResolutionResultAttr + static final Attributes.Key + XDS_CLUSTER_SUBSCRIPT_REGISTRY = + Attributes.Key.create("io.grpc.xds.XdsAttributes.xdsConfig.XdsClusterSubscriptionRegistry"); + + /** + * Attribute key for obtaining the global provider that provides atomics for aggregating + * outstanding RPCs sent to each cluster. + */ + @NameResolver.ResolutionResultAttr + static final Attributes.Key CALL_COUNTER_PROVIDER = + Attributes.Key.create("io.grpc.xds.XdsAttributes.callCounterProvider"); + + /** + * Map from localities to their weights. + */ + @NameResolver.ResolutionResultAttr + static final Attributes.Key ATTR_LOCALITY_WEIGHT = + Attributes.Key.create("io.grpc.xds.XdsAttributes.localityWeight"); + + /** + * Name of the cluster that provides this EquivalentAddressGroup. + */ + @EquivalentAddressGroup.Attr + public static final Attributes.Key ATTR_CLUSTER_NAME = + Attributes.Key.create("io.grpc.xds.XdsAttributes.clusterName"); + + /** + * The locality that this EquivalentAddressGroup is in. + */ + @EquivalentAddressGroup.Attr + static final Attributes.Key ATTR_LOCALITY = + Attributes.Key.create("io.grpc.xds.XdsAttributes.locality"); + + /** + * Endpoint weight for load balancing purposes. 
+ */ + @EquivalentAddressGroup.Attr + static final Attributes.Key ATTR_SERVER_WEIGHT = InternalEquivalentAddressGroup.ATTR_WEIGHT; + + /** + * Filter chain match for network filters. + */ + @Grpc.TransportAttr + static final Attributes.Key + ATTR_FILTER_CHAIN_SELECTOR_MANAGER = Attributes.Key.create( + "io.grpc.xds.XdsAttributes.filterChainSelectorManager"); + + /** Grace time to use when draining. Null for an infinite grace time. */ + @Grpc.TransportAttr + static final Attributes.Key ATTR_DRAIN_GRACE_NANOS = + Attributes.Key.create("io.grpc.xds.XdsAttributes.drainGraceTime"); + + private XdsAttributes() {} +} diff --git a/xds/src/main/java/io/grpc/xds/XdsClientMetricReporterImpl.java b/xds/src/main/java/io/grpc/xds/XdsClientMetricReporterImpl.java new file mode 100644 index 00000000000..5cfba11c065 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsClientMetricReporterImpl.java @@ -0,0 +1,233 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.ListenableFuture; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.LongGaugeMetricInstrument; +import io.grpc.MetricInstrumentRegistry; +import io.grpc.MetricRecorder; +import io.grpc.MetricRecorder.BatchCallback; +import io.grpc.MetricRecorder.BatchRecorder; +import io.grpc.MetricRecorder.Registration; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceMetadata; +import io.grpc.xds.client.XdsClient.ResourceMetadata.ResourceMetadataStatus; +import io.grpc.xds.client.XdsClient.ServerConnectionCallback; +import io.grpc.xds.client.XdsClientMetricReporter; +import io.grpc.xds.client.XdsResourceType; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * XdsClientMetricReporter implementation. 
+ */ +final class XdsClientMetricReporterImpl implements XdsClientMetricReporter { + + private static final Logger logger = Logger.getLogger( + XdsClientMetricReporterImpl.class.getName()); + private static final LongCounterMetricInstrument SERVER_FAILURE_COUNTER; + private static final LongCounterMetricInstrument RESOURCE_UPDATES_VALID_COUNTER; + private static final LongCounterMetricInstrument RESOURCE_UPDATES_INVALID_COUNTER; + private static final LongGaugeMetricInstrument CONNECTED_GAUGE; + private static final LongGaugeMetricInstrument RESOURCES_GAUGE; + + private final MetricRecorder metricRecorder; + private final String target; + @Nullable + private Registration gaugeRegistration = null; + + static { + MetricInstrumentRegistry metricInstrumentRegistry + = MetricInstrumentRegistry.getDefaultRegistry(); + SERVER_FAILURE_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.xds_client.server_failure", + "EXPERIMENTAL. A counter of xDS servers going from healthy to unhealthy. A server goes" + + " unhealthy when we have a connectivity failure or when the ADS stream fails without" + + " seeing a response message, as per gRFC A57.", "{failure}", + Arrays.asList("grpc.target", "grpc.xds.server"), Collections.emptyList(), false); + RESOURCE_UPDATES_VALID_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.xds_client.resource_updates_valid", + "EXPERIMENTAL. A counter of resources received that were considered valid. The counter will" + + " be incremented even for resources that have not changed.", "{resource}", + Arrays.asList("grpc.target", "grpc.xds.server", "grpc.xds.resource_type"), + Collections.emptyList(), false); + RESOURCE_UPDATES_INVALID_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.xds_client.resource_updates_invalid", + "EXPERIMENTAL. 
A counter of resources received that were considered invalid.", "{resource}", + Arrays.asList("grpc.target", "grpc.xds.server", "grpc.xds.resource_type"), + Collections.emptyList(), false); + CONNECTED_GAUGE = metricInstrumentRegistry.registerLongGauge("grpc.xds_client.connected", + "EXPERIMENTAL. Whether or not the xDS client currently has a working ADS stream to the xDS" + + " server. For a given server, this will be set to 1 when the stream is initially" + + " created. It will be set to 0 when we have a connectivity failure or when the ADS" + + " stream fails without seeing a response message, as per gRFC A57. Once set to 0, it" + + " will be reset to 1 when we receive the first response on an ADS stream.", "{bool}", + Arrays.asList("grpc.target", "grpc.xds.server"), Collections.emptyList(), false); + RESOURCES_GAUGE = metricInstrumentRegistry.registerLongGauge("grpc.xds_client.resources", + "EXPERIMENTAL. Number of xDS resources.", "{resource}", + Arrays.asList("grpc.target", "grpc.xds.authority", "grpc.xds.cache_state", + "grpc.xds.resource_type"), Collections.emptyList(), false); + } + + XdsClientMetricReporterImpl(MetricRecorder metricRecorder, String target) { + this.metricRecorder = metricRecorder; + this.target = target; + } + + @Override + public void reportResourceUpdates(long validResourceCount, long invalidResourceCount, + String xdsServer, String resourceType) { + metricRecorder.addLongCounter(RESOURCE_UPDATES_VALID_COUNTER, validResourceCount, + Arrays.asList(target, xdsServer, resourceType), Collections.emptyList()); + metricRecorder.addLongCounter(RESOURCE_UPDATES_INVALID_COUNTER, invalidResourceCount, + Arrays.asList(target, xdsServer, resourceType), Collections.emptyList()); + } + + @Override + public void reportServerFailure(long serverFailure, String xdsServer) { + metricRecorder.addLongCounter(SERVER_FAILURE_COUNTER, serverFailure, + Arrays.asList(target, xdsServer), Collections.emptyList()); + } + + void setXdsClient(XdsClient xdsClient) { + 
assert gaugeRegistration == null; + // register gauge here + this.gaugeRegistration = metricRecorder.registerBatchCallback(new BatchCallback() { + @Override + public void accept(BatchRecorder recorder) { + reportCallbackMetrics(recorder, xdsClient); + } + }, CONNECTED_GAUGE, RESOURCES_GAUGE); + } + + void close() { + if (gaugeRegistration != null) { + gaugeRegistration.close(); + gaugeRegistration = null; + } + } + + void reportCallbackMetrics(BatchRecorder recorder, XdsClient xdsClient) { + MetricReporterCallback callback = new MetricReporterCallback(recorder, target); + try { + Future reportServerConnectionsCompleted = xdsClient.reportServerConnections(callback); + + ListenableFuture, Map>> + getResourceMetadataCompleted = xdsClient.getSubscribedResourcesMetadataSnapshot(); + + Map, Map> metadataByType = + getResourceMetadataCompleted.get(10, TimeUnit.SECONDS); + + computeAndReportResourceCounts(metadataByType, callback); + + // Normally this shouldn't take long, but adding a timeout to avoid indefinite blocking + Void unused = reportServerConnectionsCompleted.get(5, TimeUnit.SECONDS); + } catch (ExecutionException | TimeoutException | InterruptedException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); // re-set the current thread's interruption state + } + logger.log(Level.WARNING, "Failed to report gauge metrics", e); + } + } + + private void computeAndReportResourceCounts( + Map, Map> metadataByType, + MetricReporterCallback callback) { + for (Map.Entry, Map> metadataByTypeEntry : + metadataByType.entrySet()) { + XdsResourceType type = metadataByTypeEntry.getKey(); + Map resources = metadataByTypeEntry.getValue(); + + Map> resourceCountsByAuthorityAndState = new HashMap<>(); + for (Map.Entry resourceEntry : resources.entrySet()) { + String resourceName = resourceEntry.getKey(); + ResourceMetadata metadata = resourceEntry.getValue(); + String authority = XdsClient.getAuthorityFromResourceName(resourceName); + String 
cacheState = cacheStateFromResourceStatus(metadata.getStatus(), metadata.isCached()); + resourceCountsByAuthorityAndState + .computeIfAbsent(authority, k -> new HashMap<>()) + .merge(cacheState, 1L, Long::sum); + } + + // Report metrics + for (Map.Entry> authorityEntry + : resourceCountsByAuthorityAndState.entrySet()) { + String authority = authorityEntry.getKey(); + Map stateCounts = authorityEntry.getValue(); + + for (Map.Entry stateEntry : stateCounts.entrySet()) { + String cacheState = stateEntry.getKey(); + Long count = stateEntry.getValue(); + + callback.reportResourceCountGauge(count, authority, cacheState, type.typeUrl()); + } + } + } + } + + private static String cacheStateFromResourceStatus(ResourceMetadataStatus metadataStatus, + boolean isResourceCached) { + switch (metadataStatus) { + case REQUESTED: + return "requested"; + case DOES_NOT_EXIST: + return "does_not_exist"; + case ACKED: + return "acked"; + case NACKED: + return isResourceCached ? "nacked_but_cached" : "nacked"; + default: + return "unknown"; + } + } + + @VisibleForTesting + static final class MetricReporterCallback implements ServerConnectionCallback { + private final BatchRecorder recorder; + private final String target; + + MetricReporterCallback(BatchRecorder recorder, String target) { + this.recorder = recorder; + this.target = target; + } + + void reportResourceCountGauge(long resourceCount, String authority, String cacheState, + String resourceType) { + // authority = #old, for non-xdstp resource names + recorder.recordLongGauge(RESOURCES_GAUGE, resourceCount, + Arrays.asList(target, authority == null ? "#old" : authority, cacheState, resourceType), + Collections.emptyList()); + } + + @Override + public void reportServerConnectionGauge(boolean isConnected, String xdsServer) { + recorder.recordLongGauge(CONNECTED_GAUGE, isConnected ? 
1 : 0, + Arrays.asList(target, xdsServer), Collections.emptyList()); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java b/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java index 313eb675116..6df8d566a7a 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java @@ -16,20 +16,19 @@ package io.grpc.xds; +import io.grpc.MetricRecorder; import io.grpc.internal.ObjectPool; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsInitializationException; import java.util.List; -import java.util.Map; import javax.annotation.Nullable; interface XdsClientPoolFactory { - void setBootstrapOverride(Map bootstrap); - @Nullable ObjectPool get(String target); - ObjectPool getOrCreate(String target) throws XdsInitializationException; + ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder); List getTargets(); } diff --git a/xds/src/main/java/io/grpc/xds/XdsClusterResource.java b/xds/src/main/java/io/grpc/xds/XdsClusterResource.java index c6340156d49..10efc47be47 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClusterResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsClusterResource.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.xds.client.Bootstrapper.ServerInfo; +import static io.grpc.xds.client.LoadStatsManager2.isEnabledOrcaLrsPropagation; import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; @@ -34,18 +35,23 @@ import io.envoyproxy.envoy.config.cluster.v3.Cluster; import io.envoyproxy.envoy.config.core.v3.RoutingPriority; import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.core.v3.TransportSocket; import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import 
io.envoyproxy.envoy.extensions.transport_sockets.http_11_proxy.v3.Http11ProxyUpstreamTransport; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.ServiceConfigUtil; import io.grpc.internal.ServiceConfigUtil.LbConfig; import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.client.BackendMetricPropagation; import io.grpc.xds.client.XdsClient.ResourceUpdate; import io.grpc.xds.client.XdsResourceType; +import io.grpc.xds.internal.security.CommonTlsContextUtil; import java.util.List; import java.util.Locale; import java.util.Set; @@ -56,12 +62,20 @@ class XdsClusterResource extends XdsResourceType { static boolean enableLeastRequest = !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) ? 
Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) - : Boolean.parseBoolean(System.getProperty("io.grpc.xds.experimentalEnableLeastRequest")); + : Boolean.parseBoolean( + System.getProperty("io.grpc.xds.experimentalEnableLeastRequest", "true")); + @VisibleForTesting + public static boolean enableSystemRootCerts = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_SYSTEM_ROOT_CERTS", true); + static boolean isEnabledXdsHttpConnect = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_HTTP_CONNECT", false); @VisibleForTesting static final String AGGREGATE_CLUSTER_TYPE_NAME = "envoy.clusters.aggregate"; static final String ADS_TYPE_URL_CDS = "type.googleapis.com/envoy.config.cluster.v3.Cluster"; + private static final String TYPE_URL_CLUSTER_CONFIG = + "type.googleapis.com/envoy.extensions.clusters.aggregate.v3.ClusterConfig"; private static final String TYPE_URL_UPSTREAM_TLS_CONTEXT = "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext"; private static final String TYPE_URL_UPSTREAM_TLS_CONTEXT_V2 = @@ -157,13 +171,26 @@ static CdsUpdate processCluster(Cluster cluster, lbConfig.getPolicyName()).parseLoadBalancingPolicyConfig( lbConfig.getRawConfigValue()); if (configOrError.getError() != null) { - throw new ResourceInvalidException(structOrError.getErrorDetail()); + throw new ResourceInvalidException( + "Failed to parse lb config for cluster '" + cluster.getName() + "': " + + configOrError.getError()); } updateBuilder.lbPolicyConfig(lbPolicyConfig); updateBuilder.filterMetadata( ImmutableMap.copyOf(cluster.getMetadata().getFilterMetadataMap())); + try { + MetadataRegistry registry = MetadataRegistry.getInstance(); + ImmutableMap parsedFilterMetadata = + registry.parseMetadata(cluster.getMetadata()); + updateBuilder.parsedMetadata(parsedFilterMetadata); + } catch (ResourceInvalidException e) { + throw new ResourceInvalidException( + "Failed to parse xDS filter metadata for cluster '" + cluster.getName() + "': " + + e.getMessage(), 
e); + } + return updateBuilder.build(); } @@ -183,6 +210,10 @@ private static StructOrError parseAggregateCluster(Cluster cl } catch (InvalidProtocolBufferException e) { return StructOrError.fromError("Cluster " + clusterName + ": malformed ClusterConfig: " + e); } + if (clusterConfig.getClustersList().isEmpty()) { + return StructOrError.fromError("Cluster " + clusterName + + ": aggregate ClusterConfig.clusters must not be empty"); + } return StructOrError.fromStruct(CdsUpdate.forAggregate( clusterName, clusterConfig.getClustersList())); } @@ -194,6 +225,13 @@ private static StructOrError parseNonAggregateCluster( Long maxConcurrentRequests = null; UpstreamTlsContext upstreamTlsContext = null; OutlierDetection outlierDetection = null; + boolean isHttp11ProxyAvailable = false; + BackendMetricPropagation backendMetricPropagation = null; + + if (isEnabledOrcaLrsPropagation) { + backendMetricPropagation = BackendMetricPropagation.fromMetricSpecs( + cluster.getLrsReportEndpointMetricsList()); + } if (cluster.hasLrsServer()) { if (!cluster.getLrsServer().hasSelf()) { return StructOrError.fromError( @@ -208,7 +246,7 @@ private static StructOrError parseNonAggregateCluster( continue; } if (threshold.hasMaxRequests()) { - maxConcurrentRequests = (long) threshold.getMaxRequests().getValue(); + maxConcurrentRequests = Integer.toUnsignedLong(threshold.getMaxRequests().getValue()); } } } @@ -216,17 +254,43 @@ private static StructOrError parseNonAggregateCluster( return StructOrError.fromError("Cluster " + clusterName + ": transport-socket-matches not supported."); } - if (cluster.hasTransportSocket()) { - if (!TRANSPORT_SOCKET_NAME_TLS.equals(cluster.getTransportSocket().getName())) { - return StructOrError.fromError("transport-socket with name " - + cluster.getTransportSocket().getName() + " not supported."); + boolean hasTransportSocket = cluster.hasTransportSocket(); + TransportSocket transportSocket = cluster.getTransportSocket(); + + if (hasTransportSocket && 
!TRANSPORT_SOCKET_NAME_TLS.equals(transportSocket.getName()) + && !(isEnabledXdsHttpConnect && transportSocket.getTypedConfig().is( + Http11ProxyUpstreamTransport.class))) { + return StructOrError.fromError( + "transport-socket with name " + transportSocket.getName() + " not supported."); + } + + if (hasTransportSocket && isEnabledXdsHttpConnect && transportSocket.getTypedConfig().is( + Http11ProxyUpstreamTransport.class)) { + isHttp11ProxyAvailable = true; + try { + Http11ProxyUpstreamTransport wrappedTransportSocket = transportSocket + .getTypedConfig().unpack(io.envoyproxy.envoy.extensions.transport_sockets + .http_11_proxy.v3.Http11ProxyUpstreamTransport.class); + hasTransportSocket = wrappedTransportSocket.hasTransportSocket(); + transportSocket = wrappedTransportSocket.getTransportSocket(); + } catch (InvalidProtocolBufferException e) { + return StructOrError.fromError( + "Cluster " + clusterName + ": malformed Http11ProxyUpstreamTransport: " + e); + } catch (ClassCastException e) { + return StructOrError.fromError( + "Cluster " + clusterName + + ": invalid transport_socket type in Http11ProxyUpstreamTransport"); } + } + + if (hasTransportSocket && TRANSPORT_SOCKET_NAME_TLS.equals(transportSocket.getName())) { try { upstreamTlsContext = UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext( validateUpstreamTlsContext( - unpackCompatibleType(cluster.getTransportSocket().getTypedConfig(), - io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext.class, - TYPE_URL_UPSTREAM_TLS_CONTEXT, TYPE_URL_UPSTREAM_TLS_CONTEXT_V2), + unpackCompatibleType(transportSocket.getTypedConfig(), + io.envoyproxy.envoy.extensions + .transport_sockets.tls.v3.UpstreamTlsContext.class, + TYPE_URL_UPSTREAM_TLS_CONTEXT, TYPE_URL_UPSTREAM_TLS_CONTEXT_V2), certProviderInstances)); } catch (InvalidProtocolBufferException | ResourceInvalidException e) { return StructOrError.fromError( @@ -264,9 +328,10 @@ private static StructOrError parseNonAggregateCluster( return 
StructOrError.fromError( "EDS service_name must be set when Cluster resource has an xdstp name"); } + return StructOrError.fromStruct(CdsUpdate.forEds( clusterName, edsServiceName, lrsServerInfo, maxConcurrentRequests, upstreamTlsContext, - outlierDetection)); + outlierDetection, isHttp11ProxyAvailable, backendMetricPropagation)); } else if (type.equals(Cluster.DiscoveryType.LOGICAL_DNS)) { if (!cluster.hasLoadAssignment()) { return StructOrError.fromError( @@ -301,7 +366,8 @@ private static StructOrError parseNonAggregateCluster( String dnsHostName = String.format( Locale.US, "%s:%d", socketAddress.getAddress(), socketAddress.getPortValue()); return StructOrError.fromStruct(CdsUpdate.forLogicalDns( - clusterName, dnsHostName, lrsServerInfo, maxConcurrentRequests, upstreamTlsContext)); + clusterName, dnsHostName, lrsServerInfo, maxConcurrentRequests, + upstreamTlsContext, isHttp11ProxyAvailable, backendMetricPropagation)); } return StructOrError.fromError( "Cluster " + clusterName + ": unsupported built-in discovery type: " + type); @@ -396,15 +462,6 @@ static void validateCommonTlsContext( throw new ResourceInvalidException( "common-tls-context with validation_context_sds_secret_config is not supported"); } - if (commonTlsContext.hasValidationContextCertificateProvider()) { - throw new ResourceInvalidException( - "common-tls-context with validation_context_certificate_provider is not supported"); - } - if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { - throw new ResourceInvalidException( - "common-tls-context with validation_context_certificate_provider_instance is not" - + " supported"); - } String certInstanceName = getIdentityCertInstanceName(commonTlsContext); if (certInstanceName == null) { if (server) { @@ -419,10 +476,6 @@ static void validateCommonTlsContext( throw new ResourceInvalidException( "tls_certificate_provider_instance is unset"); } - if (commonTlsContext.hasTlsCertificateCertificateProvider()) { - throw new 
ResourceInvalidException( - "tls_certificate_provider_instance is unset"); - } } else if (certProviderInstances == null || !certProviderInstances.contains(certInstanceName)) { throw new ResourceInvalidException( "CertificateProvider instance name '" + certInstanceName @@ -430,9 +483,11 @@ static void validateCommonTlsContext( } String rootCaInstanceName = getRootCertInstanceName(commonTlsContext); if (rootCaInstanceName == null) { - if (!server) { + if (!server && (!enableSystemRootCerts + || !CommonTlsContextUtil.isUsingSystemRootCerts(commonTlsContext))) { throw new ResourceInvalidException( - "ca_certificate_provider_instance is required in upstream-tls-context"); + "ca_certificate_provider_instance or system_root_certs is required in " + + "upstream-tls-context"); } } else { if (certProviderInstances == null || !certProviderInstances.contains(rootCaInstanceName)) { @@ -449,7 +504,9 @@ static void validateCommonTlsContext( .getDefaultValidationContext(); } if (certificateValidationContext != null) { - if (certificateValidationContext.getMatchSubjectAltNamesCount() > 0 && server) { + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names + int matchSubjectAltNamesCount = certificateValidationContext.getMatchSubjectAltNamesCount(); + if (matchSubjectAltNamesCount > 0 && server) { throw new ResourceInvalidException( "match_subject_alt_names only allowed in upstream_tls_context"); } @@ -480,10 +537,13 @@ static void validateCommonTlsContext( private static String getIdentityCertInstanceName(CommonTlsContext commonTlsContext) { if (commonTlsContext.hasTlsCertificateProviderInstance()) { return commonTlsContext.getTlsCertificateProviderInstance().getInstanceName(); - } else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { - return commonTlsContext.getTlsCertificateCertificateProviderInstance().getInstanceName(); } - return null; + // Fall back to deprecated field (field 11) for backward compatibility with Istio + 
@SuppressWarnings("deprecation") + String instanceName = commonTlsContext.hasTlsCertificateCertificateProviderInstance() + ? commonTlsContext.getTlsCertificateCertificateProviderInstance().getInstanceName() + : null; + return instanceName; } private static String getRootCertInstanceName(CommonTlsContext commonTlsContext) { @@ -500,10 +560,16 @@ private static String getRootCertInstanceName(CommonTlsContext commonTlsContext) .hasCaCertificateProviderInstance()) { return combinedCertificateValidationContext.getDefaultValidationContext() .getCaCertificateProviderInstance().getInstanceName(); - } else if (combinedCertificateValidationContext - .hasValidationContextCertificateProviderInstance()) { - return combinedCertificateValidationContext - .getValidationContextCertificateProviderInstance().getInstanceName(); + } + // Fall back to deprecated field (field 4) in CombinedValidationContext + @SuppressWarnings("deprecation") + String instanceName = combinedCertificateValidationContext + .hasValidationContextCertificateProviderInstance() + ? combinedCertificateValidationContext.getValidationContextCertificateProviderInstance() + .getInstanceName() + : null; + if (instanceName != null) { + return instanceName; } } return null; @@ -553,6 +619,8 @@ abstract static class CdsUpdate implements ResourceUpdate { @Nullable abstract UpstreamTlsContext upstreamTlsContext(); + abstract boolean isHttp11ProxyAvailable(); + // List of underlying clusters making of this aggregate cluster. // Only valid for AGGREGATE cluster. 
@Nullable @@ -564,13 +632,21 @@ abstract static class CdsUpdate implements ResourceUpdate { abstract ImmutableMap filterMetadata(); + abstract ImmutableMap parsedMetadata(); + + @Nullable + abstract BackendMetricPropagation backendMetricPropagation(); + private static Builder newBuilder(String clusterName) { return new AutoValue_XdsClusterResource_CdsUpdate.Builder() .clusterName(clusterName) .minRingSize(0) .maxRingSize(0) .choiceCount(0) - .filterMetadata(ImmutableMap.of()); + .filterMetadata(ImmutableMap.of()) + .parsedMetadata(ImmutableMap.of()) + .isHttp11ProxyAvailable(false) + .backendMetricPropagation(null); } static Builder forAggregate(String clusterName, List prioritizedClusterNames) { @@ -583,26 +659,34 @@ static Builder forAggregate(String clusterName, List prioritizedClusterN static Builder forEds(String clusterName, @Nullable String edsServiceName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext upstreamTlsContext, - @Nullable OutlierDetection outlierDetection) { + @Nullable OutlierDetection outlierDetection, + boolean isHttp11ProxyAvailable, + BackendMetricPropagation backendMetricPropagation) { return newBuilder(clusterName) .clusterType(ClusterType.EDS) .edsServiceName(edsServiceName) .lrsServerInfo(lrsServerInfo) .maxConcurrentRequests(maxConcurrentRequests) .upstreamTlsContext(upstreamTlsContext) - .outlierDetection(outlierDetection); + .outlierDetection(outlierDetection) + .isHttp11ProxyAvailable(isHttp11ProxyAvailable) + .backendMetricPropagation(backendMetricPropagation); } static Builder forLogicalDns(String clusterName, String dnsHostName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext upstreamTlsContext) { + @Nullable UpstreamTlsContext upstreamTlsContext, + boolean isHttp11ProxyAvailable, + BackendMetricPropagation backendMetricPropagation) { return newBuilder(clusterName) .clusterType(ClusterType.LOGICAL_DNS) 
.dnsHostName(dnsHostName) .lrsServerInfo(lrsServerInfo) .maxConcurrentRequests(maxConcurrentRequests) - .upstreamTlsContext(upstreamTlsContext); + .upstreamTlsContext(upstreamTlsContext) + .isHttp11ProxyAvailable(isHttp11ProxyAvailable) + .backendMetricPropagation(backendMetricPropagation); } enum ClusterType { @@ -679,6 +763,8 @@ Builder leastRequestLbPolicy(Integer choiceCount) { // Private, use one of the static factory methods instead. protected abstract Builder maxConcurrentRequests(Long maxConcurrentRequests); + protected abstract Builder isHttp11ProxyAvailable(boolean isHttp11ProxyAvailable); + // Private, use one of the static factory methods instead. protected abstract Builder upstreamTlsContext(UpstreamTlsContext upstreamTlsContext); @@ -689,6 +775,11 @@ Builder leastRequestLbPolicy(Integer choiceCount) { protected abstract Builder filterMetadata(ImmutableMap filterMetadata); + protected abstract Builder parsedMetadata(ImmutableMap parsedMetadata); + + protected abstract Builder backendMetricPropagation( + BackendMetricPropagation backendMetricPropagation); + abstract CdsUpdate build(); } } diff --git a/xds/src/main/java/io/grpc/xds/XdsConfig.java b/xds/src/main/java/io/grpc/xds/XdsConfig.java new file mode 100644 index 00000000000..d184f08de55 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsConfig.java @@ -0,0 +1,265 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.grpc.StatusOr; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.XdsListenerResource.LdsUpdate; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; +import java.io.Closeable; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * Represents the xDS configuration tree for a specified Listener. + */ +final class XdsConfig { + private final LdsUpdate listener; + private final RdsUpdate route; + private final VirtualHost virtualHost; + private final ImmutableMap> clusters; + private final int hashCode; + + XdsConfig(LdsUpdate listener, RdsUpdate route, Map> clusters, + VirtualHost virtualHost) { + this(listener, route, virtualHost, ImmutableMap.copyOf(clusters)); + } + + public XdsConfig(LdsUpdate listener, RdsUpdate route, VirtualHost virtualHost, + ImmutableMap> clusters) { + this.listener = listener; + this.route = route; + this.virtualHost = virtualHost; + this.clusters = clusters; + + hashCode = Objects.hash(listener, route, virtualHost, clusters); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof XdsConfig)) { + return false; + } + + XdsConfig o = (XdsConfig) obj; + + return hashCode() == o.hashCode() && Objects.equals(listener, o.listener) + && Objects.equals(route, o.route) && Objects.equals(virtualHost, o.virtualHost) + && Objects.equals(clusters, o.clusters); + } + + @Override + public int hashCode() { + return hashCode; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + builder.append("XdsConfig{") + .append("\n listener=").append(listener) + .append(",\n route=").append(route) + .append(",\n virtualHost=").append(virtualHost) + .append(",\n 
clusters=").append(clusters) + .append("\n}"); + return builder.toString(); + } + + public LdsUpdate getListener() { + return listener; + } + + public RdsUpdate getRoute() { + return route; + } + + public VirtualHost getVirtualHost() { + return virtualHost; + } + + public ImmutableMap> getClusters() { + return clusters; + } + + static final class XdsClusterConfig { + private final String clusterName; + private final CdsUpdate clusterResource; + private final ClusterChild children; // holds details + + XdsClusterConfig(String clusterName, CdsUpdate clusterResource, ClusterChild details) { + this.clusterName = checkNotNull(clusterName, "clusterName"); + this.clusterResource = checkNotNull(clusterResource, "clusterResource"); + this.children = checkNotNull(details, "details"); + } + + @Override + public int hashCode() { + return clusterName.hashCode() + clusterResource.hashCode() + children.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof XdsClusterConfig)) { + return false; + } + XdsClusterConfig o = (XdsClusterConfig) obj; + return Objects.equals(clusterName, o.clusterName) + && Objects.equals(clusterResource, o.clusterResource) + && Objects.equals(children, o.children); + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + builder.append("XdsClusterConfig{clusterName=").append(clusterName) + .append(", clusterResource=").append(clusterResource) + .append(", children={").append(children) + .append("}"); + return builder.toString(); + } + + public String getClusterName() { + return clusterName; + } + + public CdsUpdate getClusterResource() { + return clusterResource; + } + + public ClusterChild getChildren() { + return children; + } + + interface ClusterChild {} + + /** Endpoint info for EDS and LOGICAL_DNS clusters. If there was an + * error, endpoints will be null and resolution_note will be set. 
+ */ + static final class EndpointConfig implements ClusterChild { + private final StatusOr endpoint; + + public EndpointConfig(StatusOr endpoint) { + this.endpoint = checkNotNull(endpoint, "endpoint"); + } + + @Override + public int hashCode() { + return endpoint.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof EndpointConfig)) { + return false; + } + return Objects.equals(endpoint, ((EndpointConfig)obj).endpoint); + } + + public StatusOr getEndpoint() { + return endpoint; + } + + @Override + public String toString() { + if (endpoint.hasValue()) { + return "EndpointConfig{endpoint=" + endpoint.getValue() + "}"; + } else { + return "EndpointConfig{error=" + endpoint.getStatus() + "}"; + } + } + } + + // The list of leaf clusters for an aggregate cluster. + static final class AggregateConfig implements ClusterChild { + private final List leafNames; + + public AggregateConfig(List leafNames) { + this.leafNames = ImmutableList.copyOf(checkNotNull(leafNames, "leafNames")); + } + + public List getLeafNames() { + return leafNames; + } + + @Override + public int hashCode() { + return leafNames.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof AggregateConfig)) { + return false; + } + return Objects.equals(leafNames, ((AggregateConfig) obj).leafNames); + } + } + } + + static final class XdsConfigBuilder { + private LdsUpdate listener; + private RdsUpdate route; + private Map> clusters = new HashMap<>(); + private VirtualHost virtualHost; + + XdsConfigBuilder setListener(LdsUpdate listener) { + this.listener = checkNotNull(listener, "listener"); + return this; + } + + XdsConfigBuilder setRoute(RdsUpdate route) { + this.route = checkNotNull(route, "route"); + return this; + } + + XdsConfigBuilder addCluster(String name, StatusOr clusterConfig) { + checkNotNull(name, "name"); + checkNotNull(clusterConfig, "clusterConfig"); + clusters.put(name, clusterConfig); + return this; + } + + 
XdsConfigBuilder setVirtualHost(VirtualHost virtualHost) { + this.virtualHost = checkNotNull(virtualHost, "virtualHost"); + return this; + } + + XdsConfig build() { + checkNotNull(listener, "listener"); + checkNotNull(route, "route"); + checkNotNull(virtualHost, "virtualHost"); + return new XdsConfig(listener, route, clusters, virtualHost); + } + } + + public interface XdsClusterSubscriptionRegistry { + Subscription subscribeToCluster(String clusterName); + } + + public interface Subscription extends Closeable { + /** Release resources without throwing exceptions. */ + @Override + void close(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java b/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java index c33b3cd2f85..9dd77a400cd 100644 --- a/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java +++ b/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.InternalServiceProviders; import java.util.ArrayList; import java.util.Collections; @@ -28,10 +29,10 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -109,8 +110,10 @@ public static synchronized XdsCredentialsRegistry getDefaultRegistry() { if (instance == null) { List providerList = InternalServiceProviders.loadAll( XdsCredentialsProvider.class, - getHardCodedClasses(), - XdsCredentialsProvider.class.getClassLoader(), + ServiceLoader + .load(XdsCredentialsProvider.class, XdsCredentialsProvider.class.getClassLoader()) + .iterator(), + XdsCredentialsRegistry::getHardCodedClasses, new 
XdsCredentialsProviderPriorityAccessor()); if (providerList.isEmpty()) { logger.warning("No XdsCredsRegistry found via ServiceLoader, including for GoogleDefault, " diff --git a/xds/src/main/java/io/grpc/xds/XdsDependencyManager.java b/xds/src/main/java/io/grpc/xds/XdsDependencyManager.java new file mode 100644 index 00000000000..919836ddd9c --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsDependencyManager.java @@ -0,0 +1,949 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static io.grpc.xds.client.XdsClient.ResourceUpdate; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.SynchronizationContext; +import io.grpc.internal.RetryingNameResolver; +import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight; +import io.grpc.xds.XdsClusterResource.CdsUpdate.ClusterType; +import io.grpc.xds.XdsConfig.XdsClusterConfig.AggregateConfig; +import io.grpc.xds.XdsConfig.XdsClusterConfig.EndpointConfig; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; +import io.grpc.xds.client.Locality; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceWatcher; +import io.grpc.xds.client.XdsResourceType; +import java.net.SocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import javax.annotation.Nullable; + +/** + * This class acts as a layer of indirection between the XdsClient and the NameResolver. It + * maintains the watchers for the xds resources and when an update is received, it either requests + * referenced resources or updates the XdsConfig and notifies the XdsConfigWatcher. Each instance + * applies to a single data plane authority. 
+ */ +final class XdsDependencyManager implements XdsConfig.XdsClusterSubscriptionRegistry { + private enum TrackedWatcherTypeEnum { + LDS, RDS, CDS, EDS, DNS + } + + private static final TrackedWatcherType LDS_TYPE = + new TrackedWatcherType<>(TrackedWatcherTypeEnum.LDS); + private static final TrackedWatcherType RDS_TYPE = + new TrackedWatcherType<>(TrackedWatcherTypeEnum.RDS); + private static final TrackedWatcherType CDS_TYPE = + new TrackedWatcherType<>(TrackedWatcherTypeEnum.CDS); + private static final TrackedWatcherType EDS_TYPE = + new TrackedWatcherType<>(TrackedWatcherTypeEnum.EDS); + private static final TrackedWatcherType> DNS_TYPE = + new TrackedWatcherType<>(TrackedWatcherTypeEnum.DNS); + + // DNS-resolved endpoints do not have the definition of the locality it belongs to, just hardcode + // to an empty locality. + private static final Locality LOGICAL_DNS_CLUSTER_LOCALITY = Locality.create("", "", ""); + + private static final int MAX_CLUSTER_RECURSION_DEPTH = 16; // Specified by gRFC A37 + + static boolean enableLogicalDns = true; + + private final String listenerName; + private final XdsClient xdsClient; + private final SynchronizationContext syncContext; + private final String dataPlaneAuthority; + private final NameResolver.Args nameResolverArgs; + private XdsConfigWatcher xdsConfigWatcher; + + private StatusOr lastUpdate = null; + private final Map> resourceWatchers = + new EnumMap<>(TrackedWatcherTypeEnum.class); + private final Set subscriptions = new HashSet<>(); + + XdsDependencyManager( + XdsClient xdsClient, + SynchronizationContext syncContext, + String dataPlaneAuthority, + String listenerName, + NameResolver.Args nameResolverArgs) { + this.listenerName = checkNotNull(listenerName, "listenerName"); + this.xdsClient = checkNotNull(xdsClient, "xdsClient"); + this.syncContext = checkNotNull(syncContext, "syncContext"); + this.dataPlaneAuthority = checkNotNull(dataPlaneAuthority, "dataPlaneAuthority"); + this.nameResolverArgs = 
checkNotNull(nameResolverArgs, "nameResolverArgs"); + } + + public static String toContextStr(String typeName, String resourceName) { + return typeName + " resource " + resourceName; + } + + public void start(XdsConfigWatcher xdsConfigWatcher) { + checkState(this.xdsConfigWatcher == null, "dep manager may not be restarted"); + this.xdsConfigWatcher = checkNotNull(xdsConfigWatcher, "xdsConfigWatcher"); + // start the ball rolling + syncContext.execute(() -> addWatcher(LDS_TYPE, new LdsWatcher(listenerName))); + } + + @Override + public XdsConfig.Subscription subscribeToCluster(String clusterName) { + checkState(this.xdsConfigWatcher != null, "dep manager must first be started"); + checkNotNull(clusterName, "clusterName"); + ClusterSubscription subscription = new ClusterSubscription(clusterName); + + syncContext.execute(() -> { + if (getWatchers(LDS_TYPE).isEmpty()) { + subscription.closed = true; + return; // shutdown() called + } + subscriptions.add(subscription); + addClusterWatcher(clusterName); + }); + + return subscription; + } + + /** + * For all logical dns clusters refresh their results. 
+ */ + public void requestReresolution() { + syncContext.execute(() -> { + for (TrackedWatcher> watcher : getWatchers(DNS_TYPE).values()) { + DnsWatcher dnsWatcher = (DnsWatcher) watcher; + dnsWatcher.refresh(); + } + }); + } + + private void addWatcher( + TrackedWatcherType watcherType, XdsWatcherBase watcher) { + syncContext.throwIfNotInThisSynchronizationContext(); + XdsResourceType type = watcher.type; + String resourceName = watcher.resourceName; + + getWatchers(watcherType).put(resourceName, watcher); + xdsClient.watchXdsResource(type, resourceName, watcher, syncContext); + } + + public void shutdown() { + syncContext.execute(() -> { + for (TypeWatchers watchers : resourceWatchers.values()) { + for (TrackedWatcher watcher : watchers.watchers.values()) { + watcher.close(); + } + } + resourceWatchers.clear(); + subscriptions.clear(); + }); + } + + private void releaseSubscription(ClusterSubscription subscription) { + checkNotNull(subscription, "subscription"); + syncContext.execute(() -> { + if (subscription.closed) { + return; + } + subscription.closed = true; + if (!subscriptions.remove(subscription)) { + return; // shutdown() called + } + maybePublishConfig(); + }); + } + + /** + * Check if all resources have results, and if so, generate a new XdsConfig and send it to all + * the watchers. 
+ */ + private void maybePublishConfig() { + syncContext.throwIfNotInThisSynchronizationContext(); + if (getWatchers(LDS_TYPE).isEmpty()) { + return; // shutdown() called + } + boolean waitingOnResource = resourceWatchers.values().stream() + .flatMap(typeWatchers -> typeWatchers.watchers.values().stream()) + .anyMatch(TrackedWatcher::missingResult); + if (waitingOnResource) { + return; + } + + StatusOr newUpdate = buildUpdate(); + if (Objects.equals(newUpdate, lastUpdate)) { + return; + } + assert newUpdate.hasValue() + || (newUpdate.getStatus().getCode() == Status.Code.UNAVAILABLE + || newUpdate.getStatus().getCode() == Status.Code.INTERNAL); + lastUpdate = newUpdate; + xdsConfigWatcher.onUpdate(lastUpdate); + } + + @VisibleForTesting + StatusOr buildUpdate() { + // Create a config and discard any watchers not accessed + WatcherTracer tracer = new WatcherTracer(resourceWatchers); + StatusOr config = buildUpdate( + tracer, listenerName, dataPlaneAuthority, subscriptions); + tracer.closeUnusedWatchers(); + return config; + } + + private static StatusOr buildUpdate( + WatcherTracer tracer, + String listenerName, + String dataPlaneAuthority, + Set subscriptions) { + XdsConfig.XdsConfigBuilder builder = new XdsConfig.XdsConfigBuilder(); + + // Iterate watchers and build the XdsConfig + + TrackedWatcher ldsWatcher + = tracer.getWatcher(LDS_TYPE, listenerName); + if (ldsWatcher == null) { + return StatusOr.fromStatus(Status.UNAVAILABLE.withDescription( + "Bug: No listener watcher found for " + listenerName)); + } + if (!ldsWatcher.getData().hasValue()) { + return StatusOr.fromStatus(ldsWatcher.getData().getStatus()); + } + XdsListenerResource.LdsUpdate ldsUpdate = ldsWatcher.getData().getValue(); + builder.setListener(ldsUpdate); + + RdsUpdateSupplier routeSource = ((LdsWatcher) ldsWatcher).getRouteSource(tracer); + if (routeSource == null) { + return StatusOr.fromStatus(Status.UNAVAILABLE.withDescription( + "Bug: No route source found for listener " + 
dataPlaneAuthority)); + } + StatusOr statusOrRdsUpdate = routeSource.getRdsUpdate(); + if (!statusOrRdsUpdate.hasValue()) { + return StatusOr.fromStatus(statusOrRdsUpdate.getStatus()); + } + RdsUpdate rdsUpdate = statusOrRdsUpdate.getValue(); + builder.setRoute(rdsUpdate); + + VirtualHost activeVirtualHost = + RoutingUtils.findVirtualHostForHostName(rdsUpdate.virtualHosts, dataPlaneAuthority); + if (activeVirtualHost == null) { + String error = "Failed to find virtual host matching hostname: " + dataPlaneAuthority; + return StatusOr.fromStatus(Status.UNAVAILABLE.withDescription(error)); + } + builder.setVirtualHost(activeVirtualHost); + + Map> clusters = new HashMap<>(); + LinkedHashSet ancestors = new LinkedHashSet<>(); + for (String cluster : getClusterNamesFromVirtualHost(activeVirtualHost)) { + addConfigForCluster(clusters, cluster, ancestors, tracer); + } + for (ClusterSubscription subscription : subscriptions) { + addConfigForCluster(clusters, subscription.getClusterName(), ancestors, tracer); + } + for (Map.Entry> me : clusters.entrySet()) { + builder.addCluster(me.getKey(), me.getValue()); + } + + return StatusOr.fromValue(builder.build()); + } + + private Map> getWatchers(TrackedWatcherType watcherType) { + TypeWatchers typeWatchers = resourceWatchers.get(watcherType.typeEnum); + if (typeWatchers == null) { + typeWatchers = new TypeWatchers(watcherType); + resourceWatchers.put(watcherType.typeEnum, typeWatchers); + } + assert typeWatchers.watcherType == watcherType; + @SuppressWarnings("unchecked") + TypeWatchers tTypeWatchers = (TypeWatchers) typeWatchers; + return tTypeWatchers.watchers; + } + + private static void addConfigForCluster( + Map> clusters, + String clusterName, + @SuppressWarnings("NonApiType") // Need order-preserving set for errors + LinkedHashSet ancestors, + WatcherTracer tracer) { + if (clusters.containsKey(clusterName)) { + return; + } + if (ancestors.contains(clusterName)) { + clusters.put(clusterName, StatusOr.fromStatus( + 
Status.INTERNAL.withDescription( + "Aggregate cluster cycle detected: " + ancestors))); + return; + } + if (ancestors.size() > MAX_CLUSTER_RECURSION_DEPTH) { + clusters.put(clusterName, StatusOr.fromStatus( + Status.INTERNAL.withDescription("Recursion limit reached: " + ancestors))); + return; + } + + CdsWatcher cdsWatcher = (CdsWatcher) tracer.getWatcher(CDS_TYPE, clusterName); + StatusOr cdsWatcherDataOr = cdsWatcher.getData(); + if (!cdsWatcherDataOr.hasValue()) { + clusters.put(clusterName, StatusOr.fromStatus(cdsWatcherDataOr.getStatus())); + return; + } + + XdsClusterResource.CdsUpdate cdsUpdate = cdsWatcherDataOr.getValue(); + XdsConfig.XdsClusterConfig.ClusterChild child; + switch (cdsUpdate.clusterType()) { + case AGGREGATE: + // Re-inserting a present element into a LinkedHashSet does not reorder the entries, so it + // preserves the priority across all aggregate clusters + LinkedHashSet leafNames = new LinkedHashSet(); + ancestors.add(clusterName); + for (String childCluster : cdsUpdate.prioritizedClusterNames()) { + addConfigForCluster(clusters, childCluster, ancestors, tracer); + StatusOr config = clusters.get(childCluster); + if (!config.hasValue()) { + // gRFC A37 says: If any of a CDS policy's watchers reports that the resource does not + // exist the policy should report that it is in TRANSIENT_FAILURE. If any of the + // watchers reports a transient ADS stream error, the policy should report that it is in + // TRANSIENT_FAILURE if it has never passed a config to its child. + // + // But there's currently disagreement about whether that is actually what we want, and + // that was not originally implemented in gRPC Java. So we're keeping Java's old + // behavior for now and only failing the "leaves" (which is a bit arbitrary for a + // cycle). 
+ leafNames.add(childCluster); + continue; + } + XdsConfig.XdsClusterConfig.ClusterChild children = config.getValue().getChildren(); + if (children instanceof AggregateConfig) { + leafNames.addAll(((AggregateConfig) children).getLeafNames()); + } else { + leafNames.add(childCluster); + } + } + ancestors.remove(clusterName); + + child = new AggregateConfig(ImmutableList.copyOf(leafNames)); + break; + case EDS: + TrackedWatcher edsWatcher = + tracer.getWatcher(EDS_TYPE, cdsWatcher.getEdsServiceName()); + if (edsWatcher != null) { + child = new EndpointConfig(edsWatcher.getData()); + } else { + child = new EndpointConfig(StatusOr.fromStatus(Status.INTERNAL.withDescription( + "EDS resource not found for cluster " + clusterName))); + } + break; + case LOGICAL_DNS: + if (enableLogicalDns) { + TrackedWatcher> dnsWatcher = + tracer.getWatcher(DNS_TYPE, cdsUpdate.dnsHostName()); + child = new EndpointConfig(dnsToEdsUpdate(dnsWatcher.getData(), cdsUpdate.dnsHostName())); + } else { + child = new EndpointConfig(StatusOr.fromStatus( + Status.INTERNAL.withDescription("Logical DNS in dependency manager unsupported"))); + } + break; + default: + child = new EndpointConfig(StatusOr.fromStatus(Status.UNAVAILABLE.withDescription( + "Unknown type in cluster " + clusterName + " " + cdsUpdate.clusterType()))); + } + if (clusters.containsKey(clusterName)) { + // If a cycle is detected, we'll have detected it while recursing, so now there will be a key + // present. We don't want to overwrite it with a non-error value. 
+ return; + } + clusters.put(clusterName, StatusOr.fromValue( + new XdsConfig.XdsClusterConfig(clusterName, cdsUpdate, child))); + } + + private static StatusOr dnsToEdsUpdate( + StatusOr> dnsData, String dnsHostName) { + if (!dnsData.hasValue()) { + return StatusOr.fromStatus(dnsData.getStatus()); + } + + List addresses = new ArrayList<>(); + for (EquivalentAddressGroup eag : dnsData.getValue()) { + addresses.addAll(eag.getAddresses()); + } + EquivalentAddressGroup eag = new EquivalentAddressGroup(addresses); + List endpoints = ImmutableList.of( + Endpoints.LbEndpoint.create(eag, 1, true, dnsHostName, ImmutableMap.of())); + LocalityLbEndpoints lbEndpoints = + LocalityLbEndpoints.create(endpoints, 1, 0, ImmutableMap.of()); + return StatusOr.fromValue(new XdsEndpointResource.EdsUpdate( + "fakeEds_logicalDns", + Collections.singletonMap(LOGICAL_DNS_CLUSTER_LOCALITY, lbEndpoints), + new ArrayList<>())); + } + + private void addRdsWatcher(String resourceName) { + if (getWatchers(RDS_TYPE).containsKey(resourceName)) { + return; + } + + addWatcher(RDS_TYPE, new RdsWatcher(resourceName)); + } + + private void addEdsWatcher(String edsServiceName) { + if (getWatchers(EDS_TYPE).containsKey(edsServiceName)) { + return; + } + + addWatcher(EDS_TYPE, new EdsWatcher(edsServiceName)); + } + + private void addClusterWatcher(String clusterName) { + if (getWatchers(CDS_TYPE).containsKey(clusterName)) { + return; + } + + addWatcher(CDS_TYPE, new CdsWatcher(clusterName)); + } + + private void addDnsWatcher(String dnsHostName) { + syncContext.throwIfNotInThisSynchronizationContext(); + if (getWatchers(DNS_TYPE).containsKey(dnsHostName)) { + return; + } + + DnsWatcher watcher = new DnsWatcher(dnsHostName, nameResolverArgs); + getWatchers(DNS_TYPE).put(dnsHostName, watcher); + watcher.start(); + } + + private void updateRoutes(List virtualHosts) { + VirtualHost virtualHost = + RoutingUtils.findVirtualHostForHostName(virtualHosts, dataPlaneAuthority); + Set newClusters = 
getClusterNamesFromVirtualHost(virtualHost); + newClusters.forEach((cluster) -> addClusterWatcher(cluster)); + } + + private String nodeInfo() { + return " nodeID: " + xdsClient.getBootstrapInfo().node().getId(); + } + + private static Set getClusterNamesFromVirtualHost(VirtualHost virtualHost) { + if (virtualHost == null) { + return Collections.emptySet(); + } + + // Get all cluster names to which requests can be routed through the virtual host. + Set clusters = new HashSet<>(); + for (VirtualHost.Route route : virtualHost.routes()) { + VirtualHost.Route.RouteAction action = route.routeAction(); + if (action == null) { + continue; + } + if (action.cluster() != null) { + clusters.add(action.cluster()); + } else if (action.weightedClusters() != null) { + for (ClusterWeight weighedCluster : action.weightedClusters()) { + clusters.add(weighedCluster.name()); + } + } + } + + return clusters; + } + + private static NameResolver createNameResolver( + String dnsHostName, + NameResolver.Args nameResolverArgs) { + URI uri; + try { + uri = new URI("dns", "", "/" + dnsHostName, null); + } catch (URISyntaxException e) { + return new FailingNameResolver( + Status.INTERNAL.withDescription("Bug, invalid URI creation: " + dnsHostName) + .withCause(e)); + } + + NameResolverProvider provider = + nameResolverArgs.getNameResolverRegistry().getProviderForScheme("dns"); + if (provider == null) { + return new FailingNameResolver( + Status.INTERNAL.withDescription("Could not find dns name resolver")); + } + + NameResolver bareResolver = provider.newNameResolver(uri, nameResolverArgs); + if (bareResolver == null) { + return new FailingNameResolver( + Status.INTERNAL.withDescription("DNS name resolver provider returned null: " + uri)); + } + return RetryingNameResolver.wrap(bareResolver, nameResolverArgs); + } + + private static class TypeWatchers { + // Key is resource name + final Map> watchers = new HashMap<>(); + final TrackedWatcherType watcherType; + + TypeWatchers(TrackedWatcherType 
watcherType) { + this.watcherType = checkNotNull(watcherType, "watcherType"); + } + } + + public interface XdsConfigWatcher { + /** + * An updated XdsConfig or RPC-safe Status. The status code will be either UNAVAILABLE or + * INTERNAL. + */ + void onUpdate(StatusOr config); + } + + private final class ClusterSubscription implements XdsConfig.Subscription { + private final String clusterName; + boolean closed; // Accessed from syncContext + + public ClusterSubscription(String clusterName) { + this.clusterName = checkNotNull(clusterName, "clusterName"); + } + + String getClusterName() { + return clusterName; + } + + @Override + public void close() { + releaseSubscription(this); + } + } + + /** State for tracing garbage collector. */ + private static final class WatcherTracer { + private final Map> resourceWatchers; + private final Map> usedWatchers; + + public WatcherTracer(Map> resourceWatchers) { + this.resourceWatchers = resourceWatchers; + + this.usedWatchers = new EnumMap<>(TrackedWatcherTypeEnum.class); + for (Map.Entry> me : resourceWatchers.entrySet()) { + usedWatchers.put(me.getKey(), newTypeWatchers(me.getValue().watcherType)); + } + } + + private static TypeWatchers newTypeWatchers(TrackedWatcherType type) { + return new TypeWatchers(type); + } + + public TrackedWatcher getWatcher(TrackedWatcherType watcherType, String name) { + TypeWatchers typeWatchers = resourceWatchers.get(watcherType.typeEnum); + if (typeWatchers == null) { + return null; + } + assert typeWatchers.watcherType == watcherType; + @SuppressWarnings("unchecked") + TypeWatchers tTypeWatchers = (TypeWatchers) typeWatchers; + TrackedWatcher watcher = tTypeWatchers.watchers.get(name); + if (watcher == null) { + return null; + } + @SuppressWarnings("unchecked") + TypeWatchers usedTypeWatchers = (TypeWatchers) usedWatchers.get(watcherType.typeEnum); + usedTypeWatchers.watchers.put(name, watcher); + return watcher; + } + + /** Shut down unused watchers. 
*/ + public void closeUnusedWatchers() { + boolean changed = false; // Help out the GC by preferring old objects + for (TrackedWatcherTypeEnum key : resourceWatchers.keySet()) { + TypeWatchers orig = resourceWatchers.get(key); + TypeWatchers used = usedWatchers.get(key); + for (String name : orig.watchers.keySet()) { + if (used.watchers.containsKey(name)) { + continue; + } + orig.watchers.get(name).close(); + changed = true; + } + } + if (changed) { + resourceWatchers.putAll(usedWatchers); + } + } + } + + @SuppressWarnings("UnusedTypeParameter") + private static final class TrackedWatcherType { + public final TrackedWatcherTypeEnum typeEnum; + + public TrackedWatcherType(TrackedWatcherTypeEnum typeEnum) { + this.typeEnum = checkNotNull(typeEnum, "typeEnum"); + } + } + + private interface TrackedWatcher { + @Nullable + StatusOr getData(); + + default boolean missingResult() { + return getData() == null; + } + + default boolean hasDataValue() { + StatusOr data = getData(); + return data != null && data.hasValue(); + } + + void close(); + } + + private abstract class XdsWatcherBase + implements ResourceWatcher, TrackedWatcher { + private final XdsResourceType type; + private final String resourceName; + boolean cancelled; + + @Nullable + private StatusOr data; + @Nullable + @SuppressWarnings("unused") + private Status ambientError; + + + private XdsWatcherBase(XdsResourceType type, String resourceName) { + this.type = checkNotNull(type, "type"); + this.resourceName = checkNotNull(resourceName, "resourceName"); + } + + @Override + public void onResourceChanged(StatusOr update) { + if (cancelled) { + return; + } + ambientError = null; + if (update.hasValue()) { + data = update; + subscribeToChildren(update.getValue()); + } else { + Status status = update.getStatus(); + Status translatedStatus = Status.UNAVAILABLE.withDescription( + String.format("Error retrieving %s: %s. Details: %s%s", + toContextString(), + status.getCode(), + status.getDescription() != null ? 
status.getDescription() : "", + nodeInfo())); + + data = StatusOr.fromStatus(translatedStatus); + } + maybePublishConfig(); + } + + @Override + public void onAmbientError(Status error) { + if (cancelled) { + return; + } + ambientError = error.withDescription( + String.format("Ambient error for %s: %s. Details: %s%s", + toContextString(), + error.getCode(), + error.getDescription() != null ? error.getDescription() : "", + nodeInfo())); + } + + protected abstract void subscribeToChildren(T update); + + @Override + public void close() { + cancelled = true; + xdsClient.cancelXdsResourceWatch(type, resourceName, this); + } + + @Override + @Nullable + public StatusOr getData() { + return data; + } + + public String toContextString() { + return toContextStr(type.typeName(), resourceName); + } + } + + private interface RdsUpdateSupplier { + StatusOr getRdsUpdate(); + } + + private class LdsWatcher extends XdsWatcherBase + implements RdsUpdateSupplier { + + private LdsWatcher(String resourceName) { + super(XdsListenerResource.getInstance(), resourceName); + } + + @Override + public void subscribeToChildren(XdsListenerResource.LdsUpdate update) { + HttpConnectionManager httpConnectionManager = update.httpConnectionManager(); + List virtualHosts; + if (httpConnectionManager == null) { + // TCP listener. Unsupported config + virtualHosts = Collections.emptyList(); // Not null, to not delegate to RDS + } else { + virtualHosts = httpConnectionManager.virtualHosts(); + } + if (virtualHosts != null) { + updateRoutes(virtualHosts); + } + + String rdsName = getRdsName(update); + if (rdsName != null) { + addRdsWatcher(rdsName); + } + } + + private String getRdsName(XdsListenerResource.LdsUpdate update) { + HttpConnectionManager httpConnectionManager = update.httpConnectionManager(); + if (httpConnectionManager == null) { + // TCP listener. 
Unsupported config + return null; + } + return httpConnectionManager.rdsName(); + } + + private RdsWatcher getRdsWatcher(XdsListenerResource.LdsUpdate update, WatcherTracer tracer) { + String rdsName = getRdsName(update); + if (rdsName == null) { + return null; + } + return (RdsWatcher) tracer.getWatcher(RDS_TYPE, rdsName); + } + + public RdsUpdateSupplier getRouteSource(WatcherTracer tracer) { + if (!hasDataValue()) { + return this; + } + HttpConnectionManager hcm = getData().getValue().httpConnectionManager(); + if (hcm == null) { + return this; + } + List virtualHosts = hcm.virtualHosts(); + if (virtualHosts != null) { + return this; + } + RdsWatcher rdsWatcher = getRdsWatcher(getData().getValue(), tracer); + assert rdsWatcher != null; + return rdsWatcher; + } + + @Override + public StatusOr getRdsUpdate() { + if (missingResult()) { + return StatusOr.fromStatus(Status.UNAVAILABLE.withDescription("Not yet loaded")); + } + if (!getData().hasValue()) { + return StatusOr.fromStatus(getData().getStatus()); + } + HttpConnectionManager hcm = getData().getValue().httpConnectionManager(); + if (hcm == null) { + return StatusOr.fromStatus( + Status.UNAVAILABLE.withDescription("Not an API listener" + nodeInfo())); + } + List virtualHosts = hcm.virtualHosts(); + if (virtualHosts == null) { + // Code shouldn't trigger this case, as it should be calling RdsWatcher instead. 
This would + // be easily implemented with getRdsWatcher().getRdsUpdate(), but getting here is likely a + // bug + return StatusOr.fromStatus(Status.INTERNAL.withDescription("Routes are in RDS, not LDS")); + } + return StatusOr.fromValue(new RdsUpdate(virtualHosts)); + } + } + + private class RdsWatcher extends XdsWatcherBase implements RdsUpdateSupplier { + + public RdsWatcher(String resourceName) { + super(XdsRouteConfigureResource.getInstance(), checkNotNull(resourceName, "resourceName")); + } + + @Override + public void subscribeToChildren(RdsUpdate update) { + updateRoutes(update.virtualHosts); + } + + @Override + public StatusOr getRdsUpdate() { + if (missingResult()) { + return StatusOr.fromStatus(Status.UNAVAILABLE.withDescription("Not yet loaded")); + } + return getData(); + } + } + + private class CdsWatcher extends XdsWatcherBase { + CdsWatcher(String resourceName) { + super(XdsClusterResource.getInstance(), checkNotNull(resourceName, "resourceName")); + } + + @Override + public void subscribeToChildren(XdsClusterResource.CdsUpdate update) { + switch (update.clusterType()) { + case EDS: + addEdsWatcher(getEdsServiceName()); + break; + case LOGICAL_DNS: + if (enableLogicalDns) { + addDnsWatcher(update.dnsHostName()); + } + break; + case AGGREGATE: + update.prioritizedClusterNames() + .forEach(name -> addClusterWatcher(name)); + break; + default: + } + } + + public String getEdsServiceName() { + XdsClusterResource.CdsUpdate cdsUpdate = getData().getValue(); + assert cdsUpdate.clusterType() == ClusterType.EDS; + String edsServiceName = cdsUpdate.edsServiceName(); + if (edsServiceName == null) { + edsServiceName = cdsUpdate.clusterName(); + } + return edsServiceName; + } + } + + private class EdsWatcher extends XdsWatcherBase { + private EdsWatcher(String resourceName) { + super(XdsEndpointResource.getInstance(), checkNotNull(resourceName, "resourceName")); + } + + @Override + public void subscribeToChildren(XdsEndpointResource.EdsUpdate update) {} + } + + 
private final class DnsWatcher implements TrackedWatcher> { + private final NameResolver resolver; + @Nullable + private StatusOr> data; + private boolean cancelled; + + public DnsWatcher(String dnsHostName, NameResolver.Args nameResolverArgs) { + this.resolver = createNameResolver(dnsHostName, nameResolverArgs); + } + + public void start() { + resolver.start(new NameResolverListener()); + } + + public void refresh() { + if (cancelled) { + return; + } + resolver.refresh(); + } + + @Override + @Nullable + public StatusOr> getData() { + return data; + } + + @Override + public void close() { + if (cancelled) { + return; + } + cancelled = true; + resolver.shutdown(); + } + + private class NameResolverListener extends NameResolver.Listener2 { + @Override + public void onResult(final NameResolver.ResolutionResult resolutionResult) { + syncContext.execute(() -> onResult2(resolutionResult)); + } + + @Override + public Status onResult2(final NameResolver.ResolutionResult resolutionResult) { + if (cancelled) { + return Status.OK; + } + data = resolutionResult.getAddressesOrError(); + maybePublishConfig(); + return resolutionResult.getAddressesOrError().getStatus(); + } + + @Override + public void onError(final Status error) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (cancelled) { + return; + } + // DnsNameResolver cannot distinguish between address-not-found and transient errors. + // Assume it is a transient error. 
+ // TODO: Once the resolution note API is available, don't throw away the error if + // hasDataValue(); pass it as the note instead + if (!hasDataValue()) { + data = StatusOr.fromStatus(error); + maybePublishConfig(); + } + } + }); + } + } + } + + private static final class FailingNameResolver extends NameResolver { + private final Status status; + + public FailingNameResolver(Status status) { + checkNotNull(status, "status"); + checkArgument(!status.isOk(), "Status must not be OK"); + this.status = status; + } + + @Override + public void start(Listener2 listener) { + listener.onError(status); + } + + @Override + public String getServiceAuthority() { + return "bug-if-you-see-this-authority"; + } + + @Override + public void shutdown() {} + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java b/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java index 3ed68ac9b75..9ad75595ea6 100644 --- a/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java @@ -20,9 +20,14 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableMap; +import com.google.common.net.InetAddresses; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import io.envoyproxy.envoy.config.core.v3.Address; import io.envoyproxy.envoy.config.core.v3.HealthStatus; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; import io.envoyproxy.envoy.type.v3.FractionalPercent; @@ -30,10 +35,12 @@ import io.grpc.internal.GrpcUtil; import io.grpc.xds.Endpoints.DropOverload; import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.MetadataRegistry.MetadataValueParser; import io.grpc.xds.XdsEndpointResource.EdsUpdate; import 
io.grpc.xds.client.Locality; import io.grpc.xds.client.XdsClient.ResourceUpdate; import io.grpc.xds.client.XdsResourceType; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collections; @@ -185,7 +192,8 @@ private static int getRatePerMillion(FractionalPercent percent) { @VisibleForTesting @Nullable static StructOrError parseLocalityLbEndpoints( - io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto) { + io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto) + throws ResourceInvalidException { // Filter out localities without or with 0 weight. if (!proto.hasLoadBalancingWeight() || proto.getLoadBalancingWeight().getValue() < 1) { return null; @@ -193,6 +201,15 @@ static StructOrError parseLocalityLbEndpoints( if (proto.getPriority() < 0) { return StructOrError.fromError("negative priority"); } + + ImmutableMap localityMetadata; + MetadataRegistry registry = MetadataRegistry.getInstance(); + try { + localityMetadata = registry.parseMetadata(proto.getMetadata()); + } catch (ResourceInvalidException e) { + throw new ResourceInvalidException("Failed to parse Locality Endpoint metadata: " + + e.getMessage(), e); + } List endpoints = new ArrayList<>(proto.getLbEndpointsCount()); for (io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint endpoint : proto.getLbEndpointsList()) { // The endpoint field of each lb_endpoints must be set. 
@@ -200,6 +217,13 @@ static StructOrError parseLocalityLbEndpoints( if (!endpoint.hasEndpoint() || !endpoint.getEndpoint().hasAddress()) { return StructOrError.fromError("LbEndpoint with no endpoint/address"); } + ImmutableMap endpointMetadata; + try { + endpointMetadata = registry.parseMetadata(endpoint.getMetadata()); + } catch (ResourceInvalidException e) { + throw new ResourceInvalidException("Failed to parse Endpoint metadata: " + + e.getMessage(), e); + } List addresses = new ArrayList<>(); addresses.add(getInetSocketAddress(endpoint.getEndpoint().getAddress())); @@ -213,16 +237,25 @@ static StructOrError parseLocalityLbEndpoints( || (endpoint.getHealthStatus() == HealthStatus.UNKNOWN); endpoints.add(Endpoints.LbEndpoint.create( new EquivalentAddressGroup(addresses), - endpoint.getLoadBalancingWeight().getValue(), isHealthy)); + endpoint.getLoadBalancingWeight().getValue(), isHealthy, + endpoint.getEndpoint().getHostname(), + endpointMetadata)); } return StructOrError.fromStruct(Endpoints.LocalityLbEndpoints.create( - endpoints, proto.getLoadBalancingWeight().getValue(), proto.getPriority())); + endpoints, proto.getLoadBalancingWeight().getValue(), + proto.getPriority(), localityMetadata)); } - private static InetSocketAddress getInetSocketAddress(Address address) { + private static InetSocketAddress getInetSocketAddress(Address address) + throws ResourceInvalidException { io.envoyproxy.envoy.config.core.v3.SocketAddress socketAddress = address.getSocketAddress(); - - return new InetSocketAddress(socketAddress.getAddress(), socketAddress.getPortValue()); + InetAddress parsedAddress; + try { + parsedAddress = InetAddresses.forString(socketAddress.getAddress()); + } catch (IllegalArgumentException ex) { + throw new ResourceInvalidException("Address is not an IP", ex); + } + return new InetSocketAddress(parsedAddress, socketAddress.getPortValue()); } static final class EdsUpdate implements ResourceUpdate { @@ -269,4 +302,47 @@ public String toString() { 
.toString(); } } + + public static class AddressMetadataParser implements MetadataValueParser { + + @Override + public String getTypeUrl() { + return "type.googleapis.com/envoy.config.core.v3.Address"; + } + + @Override + public java.net.SocketAddress parse(Any any) throws ResourceInvalidException { + SocketAddress socketAddress; + try { + socketAddress = any.unpack(Address.class).getSocketAddress(); + } catch (InvalidProtocolBufferException ex) { + throw new ResourceInvalidException("Invalid Resource in address proto", ex); + } + validateAddress(socketAddress); + + String ip = socketAddress.getAddress(); + int port = socketAddress.getPortValue(); + + try { + return new InetSocketAddress(InetAddresses.forString(ip), port); + } catch (IllegalArgumentException e) { + throw createException("Invalid IP address or port: " + ip + ":" + port); + } + } + + private void validateAddress(SocketAddress socketAddress) throws ResourceInvalidException { + if (socketAddress.getAddress().isEmpty()) { + throw createException("Address field is empty or invalid."); + } + long port = Integer.toUnsignedLong(socketAddress.getPortValue()); + if (port > 65535) { + throw createException(String.format("Port value %d out of range 1-65535.", port)); + } + } + + private ResourceInvalidException createException(String message) { + return new ResourceInvalidException( + "Failed to parse envoy.config.core.v3.Address: " + message); + } + } } diff --git a/xds/src/main/java/io/grpc/xds/XdsLbPolicies.java b/xds/src/main/java/io/grpc/xds/XdsLbPolicies.java index dcca2fbfff3..ae5ac38b471 100644 --- a/xds/src/main/java/io/grpc/xds/XdsLbPolicies.java +++ b/xds/src/main/java/io/grpc/xds/XdsLbPolicies.java @@ -19,7 +19,6 @@ final class XdsLbPolicies { static final String CLUSTER_MANAGER_POLICY_NAME = "cluster_manager_experimental"; static final String CDS_POLICY_NAME = "cds_experimental"; - static final String CLUSTER_RESOLVER_POLICY_NAME = "cluster_resolver_experimental"; static final String 
PRIORITY_POLICY_NAME = "priority_experimental"; static final String CLUSTER_IMPL_POLICY_NAME = "cluster_impl_experimental"; static final String WEIGHTED_TARGET_POLICY_NAME = "weighted_target_experimental"; diff --git a/xds/src/main/java/io/grpc/xds/XdsListenerResource.java b/xds/src/main/java/io/grpc/xds/XdsListenerResource.java index af77d128ae7..041b659b4c3 100644 --- a/xds/src/main/java/io/grpc/xds/XdsListenerResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsListenerResource.java @@ -25,6 +25,7 @@ import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.common.net.InetAddresses; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; @@ -43,7 +44,6 @@ import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.client.XdsResourceType; -import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; @@ -108,13 +108,13 @@ protected LdsUpdate doParse(Args args, Message unpackedMessage) Listener listener = (Listener) unpackedMessage; if (listener.hasApiListener()) { - return processClientSideListener(listener); + return processClientSideListener(listener, args); } else { return processServerSideListener(listener, args); } } - private LdsUpdate processClientSideListener(Listener listener) + private LdsUpdate processClientSideListener(Listener listener, XdsResourceType.Args args) throws ResourceInvalidException { // Unpack HttpConnectionManager from the Listener. 
HttpConnectionManager hcm; @@ -127,10 +127,10 @@ private LdsUpdate processClientSideListener(Listener listener) "Could not parse HttpConnectionManager config from ApiListener", e); } return LdsUpdate.forApiListener( - parseHttpConnectionManager(hcm, filterRegistry, true /* isForClient */)); + parseHttpConnectionManager(hcm, filterRegistry, true /* isForClient */, args)); } - private LdsUpdate processServerSideListener(Listener proto, Args args) + private LdsUpdate processServerSideListener(Listener proto, XdsResourceType.Args args) throws ResourceInvalidException { Set certProviderInstances = null; if (args.getBootstrapInfo() != null && args.getBootstrapInfo().certProviders() != null) { @@ -138,19 +138,19 @@ private LdsUpdate processServerSideListener(Listener proto, Args args) } return LdsUpdate.forTcpListener(parseServerSideListener(proto, (TlsContextManager) args.getSecurityConfig(), - filterRegistry, certProviderInstances)); + filterRegistry, certProviderInstances, args)); } @VisibleForTesting static EnvoyServerProtoData.Listener parseServerSideListener( Listener proto, TlsContextManager tlsContextManager, - FilterRegistry filterRegistry, Set certProviderInstances) + FilterRegistry filterRegistry, Set certProviderInstances, XdsResourceType.Args args) throws ResourceInvalidException { - if (!proto.getTrafficDirection().equals(TrafficDirection.INBOUND) - && !proto.getTrafficDirection().equals(TrafficDirection.UNSPECIFIED)) { + TrafficDirection trafficDirection = proto.getTrafficDirection(); + if (!trafficDirection.equals(TrafficDirection.INBOUND) + && !trafficDirection.equals(TrafficDirection.UNSPECIFIED)) { throw new ResourceInvalidException( - "Listener " + proto.getName() + " with invalid traffic direction: " - + proto.getTrafficDirection()); + "Listener " + proto.getName() + " with invalid traffic direction: " + trafficDirection); } if (!proto.getListenerFiltersList().isEmpty()) { throw new ResourceInvalidException( @@ -162,13 +162,16 @@ static 
EnvoyServerProtoData.Listener parseServerSideListener( } String address = null; + SocketAddress socketAddress = null; if (proto.getAddress().hasSocketAddress()) { - SocketAddress socketAddress = proto.getAddress().getSocketAddress(); + socketAddress = proto.getAddress().getSocketAddress(); address = socketAddress.getAddress(); + if (address.isEmpty()) { + throw new ResourceInvalidException("Invalid address: Empty address is not allowed."); + } switch (socketAddress.getPortSpecifierCase()) { case NAMED_PORT: - address = address + ":" + socketAddress.getNamedPort(); - break; + throw new ResourceInvalidException("NAMED_PORT is not supported in gRPC."); case PORT_VALUE: address = address + ":" + socketAddress.getPortValue(); break; @@ -178,56 +181,82 @@ static EnvoyServerProtoData.Listener parseServerSideListener( } ImmutableList.Builder filterChains = ImmutableList.builder(); - Set uniqueSet = new HashSet<>(); + Set filterChainNames = new HashSet<>(); + Set filterChainMatchSet = new HashSet<>(); + int i = 0; for (io.envoyproxy.envoy.config.listener.v3.FilterChain fc : proto.getFilterChainsList()) { + // May be empty. If it's not empty, required to be unique. + String filterChainName = fc.getName(); + if (filterChainName.isEmpty()) { + // Generate a name, so we can identify it in the logs. + filterChainName = "chain_" + i; + } + if (!filterChainNames.add(filterChainName)) { + throw new ResourceInvalidException("Filter chain names must be unique. 
" + + "Found duplicate: " + filterChainName); + } filterChains.add( - parseFilterChain(fc, tlsContextManager, filterRegistry, uniqueSet, - certProviderInstances)); + parseFilterChain(fc, filterChainName, tlsContextManager, filterRegistry, + filterChainMatchSet, certProviderInstances, args)); + i++; } + FilterChain defaultFilterChain = null; if (proto.hasDefaultFilterChain()) { + String defaultFilterChainName = proto.getDefaultFilterChain().getName(); + if (defaultFilterChainName.isEmpty()) { + defaultFilterChainName = "chain_default"; + } defaultFilterChain = parseFilterChain( - proto.getDefaultFilterChain(), tlsContextManager, filterRegistry, - null, certProviderInstances); + proto.getDefaultFilterChain(), defaultFilterChainName, tlsContextManager, filterRegistry, + null, certProviderInstances, args); } - return EnvoyServerProtoData.Listener.create( - proto.getName(), address, filterChains.build(), defaultFilterChain); + return EnvoyServerProtoData.Listener.create(proto.getName(), address, filterChains.build(), + defaultFilterChain, socketAddress == null ? null : socketAddress.getProtocol()); } @VisibleForTesting static FilterChain parseFilterChain( io.envoyproxy.envoy.config.listener.v3.FilterChain proto, - TlsContextManager tlsContextManager, FilterRegistry filterRegistry, - Set uniqueSet, Set certProviderInstances) + String filterChainName, + TlsContextManager tlsContextManager, + FilterRegistry filterRegistry, + // null disables FilterChainMatch uniqueness check, used for defaultFilterChain + @Nullable Set filterChainMatchSet, + Set certProviderInstances, + XdsResourceType.Args args) throws ResourceInvalidException { + // FilterChain contains L4 filters, so we ensure it contains only HCM. 
if (proto.getFiltersCount() != 1) { - throw new ResourceInvalidException("FilterChain " + proto.getName() + throw new ResourceInvalidException("FilterChain " + filterChainName + " should contain exact one HttpConnectionManager filter"); } - io.envoyproxy.envoy.config.listener.v3.Filter filter = proto.getFiltersList().get(0); - if (!filter.hasTypedConfig()) { + io.envoyproxy.envoy.config.listener.v3.Filter l4Filter = proto.getFiltersList().get(0); + if (!l4Filter.hasTypedConfig()) { throw new ResourceInvalidException( - "FilterChain " + proto.getName() + " contains filter " + filter.getName() + "FilterChain " + filterChainName + " contains filter " + l4Filter.getName() + " without typed_config"); } - Any any = filter.getTypedConfig(); - // HttpConnectionManager is the only supported network filter at the moment. + Any any = l4Filter.getTypedConfig(); if (!any.getTypeUrl().equals(TYPE_URL_HTTP_CONNECTION_MANAGER)) { throw new ResourceInvalidException( - "FilterChain " + proto.getName() + " contains filter " + filter.getName() + "FilterChain " + filterChainName + " contains filter " + l4Filter.getName() + " with unsupported typed_config type " + any.getTypeUrl()); } + + // Parse HCM. HttpConnectionManager hcmProto; try { hcmProto = any.unpack(HttpConnectionManager.class); } catch (InvalidProtocolBufferException e) { - throw new ResourceInvalidException("FilterChain " + proto.getName() + " with filter " - + filter.getName() + " failed to unpack message", e); + throw new ResourceInvalidException("FilterChain " + filterChainName + " with filter " + + l4Filter.getName() + " failed to unpack message", e); } io.grpc.xds.HttpConnectionManager httpConnectionManager = parseHttpConnectionManager( - hcmProto, filterRegistry, false /* isForClient */); + hcmProto, filterRegistry, false /* isForClient */, args); + // Parse Transport Socket. 
EnvoyServerProtoData.DownstreamTlsContext downstreamTlsContext = null; if (proto.hasTransportSocket()) { if (!TRANSPORT_SOCKET_NAME_TLS.equals(proto.getTransportSocket().getName())) { @@ -239,7 +268,7 @@ static FilterChain parseFilterChain( downstreamTlsContextProto = proto.getTransportSocket().getTypedConfig().unpack(DownstreamTlsContext.class); } catch (InvalidProtocolBufferException e) { - throw new ResourceInvalidException("FilterChain " + proto.getName() + throw new ResourceInvalidException("FilterChain " + filterChainName + " failed to unpack message", e); } downstreamTlsContext = @@ -247,10 +276,15 @@ static FilterChain parseFilterChain( validateDownstreamTlsContext(downstreamTlsContextProto, certProviderInstances)); } + // Parse FilterChainMatch. FilterChainMatch filterChainMatch = parseFilterChainMatch(proto.getFilterChainMatch()); - checkForUniqueness(uniqueSet, filterChainMatch); + // null used to skip this check for defaultFilterChain. + if (filterChainMatchSet != null) { + validateFilterChainMatchForUniqueness(filterChainMatchSet, filterChainMatch); + } + return FilterChain.create( - proto.getName(), + filterChainName, filterChainMatch, httpConnectionManager, downstreamTlsContext, @@ -284,15 +318,15 @@ static DownstreamTlsContext validateDownstreamTlsContext( return downstreamTlsContext; } - private static void checkForUniqueness(Set uniqueSet, + private static void validateFilterChainMatchForUniqueness( + Set filterChainMatchSet, FilterChainMatch filterChainMatch) throws ResourceInvalidException { - if (uniqueSet != null) { - List crossProduct = getCrossProduct(filterChainMatch); - for (FilterChainMatch cur : crossProduct) { - if (!uniqueSet.add(cur)) { - throw new ResourceInvalidException("FilterChainMatch must be unique. " - + "Found duplicate: " + cur); - } + // Flattens complex FilterChainMatch into a list of simple FilterChainMatch'es. 
+ List crossProduct = getCrossProduct(filterChainMatch); + for (FilterChainMatch cur : crossProduct) { + if (!filterChainMatchSet.add(cur)) { + throw new ResourceInvalidException("FilterChainMatch must be unique. " + + "Found duplicate: " + cur); } } } @@ -420,16 +454,18 @@ private static FilterChainMatch parseFilterChainMatch( try { for (io.envoyproxy.envoy.config.core.v3.CidrRange range : proto.getPrefixRangesList()) { prefixRanges.add( - CidrRange.create(range.getAddressPrefix(), range.getPrefixLen().getValue())); + CidrRange.create(InetAddresses.forString(range.getAddressPrefix()), + range.getPrefixLen().getValue())); } for (io.envoyproxy.envoy.config.core.v3.CidrRange range : proto.getSourcePrefixRangesList()) { - sourcePrefixRanges.add( - CidrRange.create(range.getAddressPrefix(), range.getPrefixLen().getValue())); + sourcePrefixRanges.add(CidrRange.create( + InetAddresses.forString(range.getAddressPrefix()), range.getPrefixLen().getValue())); } - } catch (UnknownHostException e) { - throw new ResourceInvalidException("Failed to create CidrRange", e); + } catch (IllegalArgumentException ex) { + throw new ResourceInvalidException("Failed to create CidrRange", ex); } + ConnectionSourceType sourceType; switch (proto.getSourceType()) { case ANY: @@ -458,7 +494,7 @@ private static FilterChainMatch parseFilterChainMatch( @VisibleForTesting static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( HttpConnectionManager proto, FilterRegistry filterRegistry, - boolean isForClient) throws ResourceInvalidException { + boolean isForClient, XdsResourceType.Args args) throws ResourceInvalidException { if (proto.getXffNumTrustedHops() != 0) { throw new ResourceInvalidException( "HttpConnectionManager with xff_num_trusted_hops unsupported"); @@ -515,7 +551,7 @@ static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( // Parse inlined RouteConfiguration or RDS. 
if (proto.hasRouteConfig()) { List virtualHosts = extractVirtualHosts( - proto.getRouteConfig(), filterRegistry); + proto.getRouteConfig(), filterRegistry, args); return io.grpc.xds.HttpConnectionManager.forVirtualHosts( maxStreamDuration, virtualHosts, filterConfigs); } @@ -549,12 +585,8 @@ static StructOrError parseHttpFilter( String filterName = httpFilter.getName(); boolean isOptional = httpFilter.getIsOptional(); if (!httpFilter.hasTypedConfig()) { - if (isOptional) { - return null; - } else { - return StructOrError.fromError( - "HttpFilter [" + filterName + "] is not optional and has no typed config"); - } + return isOptional ? null : StructOrError.fromError( + "HttpFilter [" + filterName + "] is not optional and has no typed config"); } Message rawConfig = httpFilter.getTypedConfig(); String typeUrl = httpFilter.getTypedConfig().getTypeUrl(); @@ -574,18 +606,17 @@ static StructOrError parseHttpFilter( return StructOrError.fromError( "HttpFilter [" + filterName + "] contains invalid proto: " + e); } - Filter filter = filterRegistry.get(typeUrl); - if ((isForClient && !(filter instanceof Filter.ClientInterceptorBuilder)) - || (!isForClient && !(filter instanceof Filter.ServerInterceptorBuilder))) { - if (isOptional) { - return null; - } else { - return StructOrError.fromError( - "HttpFilter [" + filterName + "](" + typeUrl + ") is required but unsupported for " - + (isForClient ? "client" : "server")); - } + + Filter.Provider provider = filterRegistry.get(typeUrl); + if (provider == null + || (isForClient && !provider.isClientFilter()) + || (!isForClient && !provider.isServerFilter())) { + // Filter type not supported. + return isOptional ? null : StructOrError.fromError( + "HttpFilter [" + filterName + "](" + typeUrl + ") is required but unsupported for " + ( + isForClient ? 
"client" : "server")); } - ConfigOrError filterConfig = filter.parseFilterConfig(rawConfig); + ConfigOrError filterConfig = provider.parseFilterConfig(rawConfig); if (filterConfig.errorDetail != null) { return StructOrError.fromError( "Invalid filter config for HttpFilter [" + filterName + "]: " + filterConfig.errorDetail); diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index f0329387fc9..69b0b824433 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -41,14 +41,15 @@ import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; import io.grpc.xds.ClusterSpecifierPlugin.PluginConfig; -import io.grpc.xds.Filter.ClientInterceptorBuilder; import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; import io.grpc.xds.RouteLookupServiceClusterSpecifierPlugin.RlsPluginConfig; @@ -58,15 +59,14 @@ import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight; import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy; import io.grpc.xds.VirtualHost.Route.RouteAction.RetryPolicy; +import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; -import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.client.Bootstrapper.AuthorityInfo; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsClient.ResourceWatcher; +import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.client.XdsLogger; import io.grpc.xds.client.XdsLogger.XdsLogLevel; -import java.net.URI; import 
java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -80,6 +80,7 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; import javax.annotation.Nullable; /** @@ -91,11 +92,14 @@ * @see XdsNameResolverProvider */ final class XdsNameResolver extends NameResolver { - static final CallOptions.Key CLUSTER_SELECTION_KEY = CallOptions.Key.create("io.grpc.xds.CLUSTER_SELECTION_KEY"); + static final CallOptions.Key XDS_CONFIG_CALL_OPTION_KEY = + CallOptions.Key.create("io.grpc.xds.XDS_CONFIG_CALL_OPTION_KEY"); static final CallOptions.Key RPC_HASH_KEY = CallOptions.Key.create("io.grpc.xds.RPC_HASH_KEY"); + static final CallOptions.Key AUTO_HOST_REWRITE_KEY = + CallOptions.Key.create("io.grpc.xds.AUTO_HOST_REWRITE_KEY"); @VisibleForTesting static boolean enableTimeout = Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT")) @@ -105,7 +109,6 @@ final class XdsNameResolver extends NameResolver { private final XdsLogger logger; @Nullable private final String targetAuthority; - private final String target; private final String serviceAuthority; // Encoded version of the service authority as per // https://datatracker.ietf.org/doc/html/rfc3986#section-3.2. 
@@ -114,7 +117,7 @@ final class XdsNameResolver extends NameResolver { private final ServiceConfigParser serviceConfigParser; private final SynchronizationContext syncContext; private final ScheduledExecutorService scheduler; - private final XdsClientPoolFactory xdsClientPoolFactory; + private final XdsClientPool xdsClientPool; private final ThreadSafeRandom random; private final FilterRegistry filterRegistry; private final XxHash64 hashFunc = XxHash64.INSTANCE; @@ -123,36 +126,49 @@ final class XdsNameResolver extends NameResolver { private final ConcurrentMap clusterRefs = new ConcurrentHashMap<>(); private final ConfigSelector configSelector = new ConfigSelector(); private final long randomChannelId; + private final Args nameResolverArgs; + // Must be accessed in syncContext. + // Filter instances are unique per channel, and per filter (name+typeUrl). + // NamedFilterConfig.filterStateKey -> filter_instance. + private final HashMap activeFilters = new HashMap<>(); - private volatile RoutingConfig routingConfig = RoutingConfig.empty; + private volatile RoutingConfig routingConfig; private Listener2 listener; - private ObjectPool xdsClientPool; private XdsClient xdsClient; private CallCounterProvider callCounterProvider; private ResolveState resolveState; - // Workaround for https://github.com/grpc/grpc-java/issues/8886 . This should be handled in - // XdsClient instead of here. - private boolean receivedConfig; + /** + * Constructs a new instance. 
+ * + * @param target the target URI to resolve + * @param targetAuthority the authority component of `target`, possibly the empty string, or null + * if 'target' has no such component + */ XdsNameResolver( - URI targetUri, String name, @Nullable String overrideAuthority, - ServiceConfigParser serviceConfigParser, + String target, @Nullable String targetAuthority, String name, + @Nullable String overrideAuthority, ServiceConfigParser serviceConfigParser, SynchronizationContext syncContext, ScheduledExecutorService scheduler, - @Nullable Map bootstrapOverride) { - this(targetUri, targetUri.getAuthority(), name, overrideAuthority, serviceConfigParser, - syncContext, scheduler, SharedXdsClientPoolProvider.getDefaultProvider(), - ThreadSafeRandomImpl.instance, FilterRegistry.getDefaultRegistry(), bootstrapOverride); + @Nullable Map bootstrapOverride, + MetricRecorder metricRecorder, Args nameResolverArgs) { + this(target, targetAuthority, name, overrideAuthority, serviceConfigParser, + syncContext, scheduler, + bootstrapOverride == null + ? SharedXdsClientPoolProvider.getDefaultProvider() + : new SharedXdsClientPoolProvider(), + ThreadSafeRandomImpl.instance, FilterRegistry.getDefaultRegistry(), bootstrapOverride, + metricRecorder, nameResolverArgs); } @VisibleForTesting XdsNameResolver( - URI targetUri, @Nullable String targetAuthority, String name, + String target, @Nullable String targetAuthority, String name, @Nullable String overrideAuthority, ServiceConfigParser serviceConfigParser, SynchronizationContext syncContext, ScheduledExecutorService scheduler, XdsClientPoolFactory xdsClientPoolFactory, ThreadSafeRandom random, - FilterRegistry filterRegistry, @Nullable Map bootstrapOverride) { + FilterRegistry filterRegistry, @Nullable Map bootstrapOverride, + MetricRecorder metricRecorder, Args nameResolverArgs) { this.targetAuthority = targetAuthority; - target = targetUri.toString(); // The name might have multiple slashes so encode it before verifying. 
serviceAuthority = checkNotNull(name, "name"); @@ -163,11 +179,19 @@ final class XdsNameResolver extends NameResolver { this.serviceConfigParser = checkNotNull(serviceConfigParser, "serviceConfigParser"); this.syncContext = checkNotNull(syncContext, "syncContext"); this.scheduler = checkNotNull(scheduler, "scheduler"); - this.xdsClientPoolFactory = bootstrapOverride == null ? checkNotNull(xdsClientPoolFactory, - "xdsClientPoolFactory") : new SharedXdsClientPoolProvider(); - this.xdsClientPoolFactory.setBootstrapOverride(bootstrapOverride); + Supplier xdsClientSupplierArg = + nameResolverArgs.getArg(XdsNameResolverProvider.XDS_CLIENT_SUPPLIER); + if (xdsClientSupplierArg != null) { + this.xdsClientPool = new SupplierXdsClientPool(xdsClientSupplierArg); + } else { + checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + this.xdsClientPool = new BootstrappingXdsClientPool( + xdsClientPoolFactory, target, bootstrapOverride, metricRecorder); + } this.random = checkNotNull(random, "random"); this.filterRegistry = checkNotNull(filterRegistry, "filterRegistry"); + this.nameResolverArgs = checkNotNull(nameResolverArgs, "nameResolverArgs"); + randomChannelId = random.nextLong(); logId = InternalLogId.allocate("xds-resolver", name); logger = XdsLogger.withLogId(logId); @@ -183,16 +207,17 @@ public String getServiceAuthority() { public void start(Listener2 listener) { this.listener = checkNotNull(listener, "listener"); try { - xdsClientPool = xdsClientPoolFactory.getOrCreate(target); + xdsClient = xdsClientPool.getObject(); } catch (Exception e) { listener.onError( Status.UNAVAILABLE.withDescription("Failed to initialize xDS").withCause(e)); return; } - xdsClient = xdsClientPool.getObject(); BootstrapInfo bootstrapInfo = xdsClient.getBootstrapInfo(); String listenerNameTemplate; - if (targetAuthority == null) { + if (targetAuthority == null || targetAuthority.isEmpty()) { + // Both https://github.com/grpc/proposal/blob/master/A27-xds-global-load-balancing.md and + // 
A47-xds-federation.md seem to treat an empty authority the same as an undefined one. listenerNameTemplate = bootstrapInfo.clientDefaultListenerResourceNameTemplate(); } else { AuthorityInfo authorityInfo = bootstrapInfo.authorities().get(targetAuthority); @@ -216,10 +241,18 @@ public void start(Listener2 listener) { } ldsResourceName = XdsClient.canonifyResourceName(ldsResourceName); callCounterProvider = SharedCallCounterMap.getInstance(); + resolveState = new ResolveState(ldsResourceName); resolveState.start(); } + @Override + public void refresh() { + if (resolveState != null) { + resolveState.refresh(); + } + } + private static String expandPercentS(String template, String replacement) { return template.replace("%s", replacement); } @@ -228,7 +261,7 @@ private static String expandPercentS(String template, String replacement) { public void shutdown() { logger.log(XdsLogLevel.INFO, "Shutdown"); if (resolveState != null) { - resolveState.stop(); + resolveState.shutdown(); } if (xdsClient != null) { xdsClient = xdsClientPool.returnObject(xdsClient); @@ -277,7 +310,7 @@ XdsClient getXdsClient() { } // called in syncContext - private void updateResolutionResult() { + private void updateResolutionResult(XdsConfig xdsConfig) { syncContext.throwIfNotInThisSynchronizationContext(); ImmutableMap.Builder childPolicy = new ImmutableMap.Builder<>(); @@ -293,13 +326,15 @@ private void updateResolutionResult() { if (logger.isLoggable(XdsLogLevel.INFO)) { logger.log( - XdsLogLevel.INFO, "Generated service config:\n{0}", new Gson().toJson(rawServiceConfig)); + XdsLogLevel.INFO, "Generated service config: {0}", new Gson().toJson(rawServiceConfig)); } ConfigOrError parsedServiceConfig = serviceConfigParser.parseServiceConfig(rawServiceConfig); Attributes attrs = Attributes.newBuilder() - .set(InternalXdsAttributes.XDS_CLIENT_POOL, xdsClientPool) - .set(InternalXdsAttributes.CALL_COUNTER_PROVIDER, callCounterProvider) + .set(XdsAttributes.XDS_CLIENT, xdsClient) + 
.set(XdsAttributes.XDS_CONFIG, xdsConfig) + .set(XdsAttributes.XDS_CLUSTER_SUBSCRIPT_REGISTRY, resolveState.xdsDependencyManager) + .set(XdsAttributes.CALL_COUNTER_PROVIDER, callCounterProvider) .set(InternalConfigSelector.KEY, configSelector) .build(); ResolutionResult result = @@ -307,8 +342,9 @@ private void updateResolutionResult() { .setAttributes(attrs) .setServiceConfig(parsedServiceConfig) .build(); - listener.onResult(result); - receivedConfig = true; + if (!listener.onResult2(result).isOk()) { + resolveState.xdsDependencyManager.requestReresolution(); + } } /** @@ -374,21 +410,21 @@ static boolean matchHostName(String hostName, String pattern) { private final class ConfigSelector extends InternalConfigSelector { @Override public Result selectConfig(PickSubchannelArgs args) { - String cluster = null; - Route selectedRoute = null; RoutingConfig routingCfg; - Map selectedOverrideConfigs; - List filterInterceptors = new ArrayList<>(); + RouteData selectedRoute; + String cluster; + ClientInterceptor filters; Metadata headers = args.getHeaders(); + String path = "/" + args.getMethodDescriptor().getFullMethodName(); do { routingCfg = routingConfig; - selectedOverrideConfigs = new HashMap<>(routingCfg.virtualHostOverrideConfig); - for (Route route : routingCfg.routes) { - if (RoutingUtils.matchRoute( - route.routeMatch(), "/" + args.getMethodDescriptor().getFullMethodName(), - headers, random)) { + if (routingCfg.errorStatus != null) { + return Result.forError(routingCfg.errorStatus); + } + selectedRoute = null; + for (RouteData route : routingCfg.routes) { + if (RoutingUtils.matchRoute(route.routeMatch, path, headers, random)) { selectedRoute = route; - selectedOverrideConfigs.putAll(route.filterConfigOverrides()); break; } } @@ -396,38 +432,45 @@ public Result selectConfig(PickSubchannelArgs args) { return Result.forError( Status.UNAVAILABLE.withDescription("Could not find xDS route matching RPC")); } - if (selectedRoute.routeAction() == null) { + if 
(selectedRoute.routeAction == null) { return Result.forError(Status.UNAVAILABLE.withDescription( "Could not route RPC to Route with non-forwarding action")); } - RouteAction action = selectedRoute.routeAction(); + RouteAction action = selectedRoute.routeAction; if (action.cluster() != null) { cluster = prefixedClusterName(action.cluster()); + filters = selectedRoute.filterChoices.get(0); } else if (action.weightedClusters() != null) { + // XdsRouteConfigureResource verifies the total weight will not be 0 or exceed uint32 long totalWeight = 0; for (ClusterWeight weightedCluster : action.weightedClusters()) { totalWeight += weightedCluster.weight(); } long select = random.nextLong(totalWeight); long accumulator = 0; - for (ClusterWeight weightedCluster : action.weightedClusters()) { + for (int i = 0; ; i++) { + ClusterWeight weightedCluster = action.weightedClusters().get(i); accumulator += weightedCluster.weight(); if (select < accumulator) { cluster = prefixedClusterName(weightedCluster.name()); - selectedOverrideConfigs.putAll(weightedCluster.filterConfigOverrides()); + filters = selectedRoute.filterChoices.get(i); break; } } } else if (action.namedClusterSpecifierPluginConfig() != null) { cluster = prefixedClusterSpecifierPluginName(action.namedClusterSpecifierPluginConfig().name()); + filters = selectedRoute.filterChoices.get(0); + } else { + // updateRoutes() discards routes with unknown actions + throw new AssertionError(); } } while (!retainCluster(cluster)); + + final RouteAction routeAction = selectedRoute.routeAction; Long timeoutNanos = null; if (enableTimeout) { - if (selectedRoute != null) { - timeoutNanos = selectedRoute.routeAction().timeoutNano(); - } + timeoutNanos = routeAction.timeoutNano(); if (timeoutNanos == null) { timeoutNanos = routingCfg.fallbackTimeoutNano; } @@ -435,8 +478,7 @@ public Result selectConfig(PickSubchannelArgs args) { timeoutNanos = null; } } - RetryPolicy retryPolicy = - selectedRoute == null ? 
null : selectedRoute.routeAction().retryPolicy(); + RetryPolicy retryPolicy = routeAction.retryPolicy(); // TODO(chengyuanzhang): avoid service config generation and parsing for each call. Map rawServiceConfig = generateServiceConfigWithMethodConfig(timeoutNanos, retryPolicy); @@ -448,31 +490,21 @@ public Result selectConfig(PickSubchannelArgs args) { parsedServiceConfig.getError().augmentDescription( "Failed to parse service config (method config)")); } - if (routingCfg.filterChain != null) { - for (NamedFilterConfig namedFilter : routingCfg.filterChain) { - FilterConfig filterConfig = namedFilter.filterConfig; - Filter filter = filterRegistry.get(filterConfig.typeUrl()); - if (filter instanceof ClientInterceptorBuilder) { - ClientInterceptor interceptor = ((ClientInterceptorBuilder) filter) - .buildClientInterceptor( - filterConfig, selectedOverrideConfigs.get(namedFilter.name), - args, scheduler); - if (interceptor != null) { - filterInterceptors.add(interceptor); - } - } - } - } final String finalCluster = cluster; - final long hash = generateHash(selectedRoute.routeAction().hashPolicies(), headers); + final XdsConfig xdsConfig = routingCfg.xdsConfig; + final long hash = generateHash(routeAction.hashPolicies(), headers); class ClusterSelectionInterceptor implements ClientInterceptor { @Override public ClientCall interceptCall( final MethodDescriptor method, CallOptions callOptions, final Channel next) { - final CallOptions callOptionsForCluster = + CallOptions callOptionsForCluster = callOptions.withOption(CLUSTER_SELECTION_KEY, finalCluster) + .withOption(XDS_CONFIG_CALL_OPTION_KEY, xdsConfig) .withOption(RPC_HASH_KEY, hash); + if (routeAction.autoHostRewrite()) { + callOptionsForCluster = callOptionsForCluster.withOption(AUTO_HOST_REWRITE_KEY, true); + } return new SimpleForwardingClientCall( next.newCall(method, callOptionsForCluster)) { @Override @@ -501,11 +533,11 @@ public void onClose(Status status, Metadata trailers) { } } - filterInterceptors.add(new 
ClusterSelectionInterceptor()); return Result.newBuilder() .setConfig(config) - .setInterceptor(combineInterceptors(filterInterceptors)) + .setInterceptor(combineInterceptors( + ImmutableList.of(new ClusterSelectionInterceptor(), filters))) .build(); } @@ -527,13 +559,21 @@ private boolean retainCluster(String cluster) { private void releaseCluster(final String cluster) { int count = clusterRefs.get(cluster).refCount.decrementAndGet(); + if (count < 0) { + throw new AssertionError(); + } if (count == 0) { syncContext.execute(new Runnable() { @Override public void run() { - if (clusterRefs.get(cluster).refCount.get() == 0) { - clusterRefs.remove(cluster); - updateResolutionResult(); + if (clusterRefs.get(cluster).refCount.get() != 0) { + throw new AssertionError(); + } + clusterRefs.remove(cluster).close(); + if (resolveState.lastConfigOrStatus.hasValue()) { + updateResolutionResult(resolveState.lastConfigOrStatus.getValue()); + } else { + resolveState.cleanUpRoutes(resolveState.lastConfigOrStatus.getStatus()); } } }); @@ -571,8 +611,18 @@ private long generateHash(List hashPolicies, Metadata headers) { } } + static final class PassthroughClientInterceptor implements ClientInterceptor { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + } + private static ClientInterceptor combineInterceptors(final List interceptors) { - checkArgument(!interceptors.isEmpty(), "empty interceptors"); + if (interceptors.size() == 0) { + return new PassthroughClientInterceptor(); + } if (interceptors.size() == 1) { return interceptors.get(0); } @@ -612,103 +662,106 @@ private static String prefixedClusterSpecifierPluginName(String pluginName) { return "cluster_specifier_plugin:" + pluginName; } - private static final class FailingConfigSelector extends InternalConfigSelector { - private final Result result; - - public FailingConfigSelector(Status error) { - this.result = 
Result.forError(error); - } - - @Override - public Result selectConfig(PickSubchannelArgs args) { - return result; - } - } - - private class ResolveState implements ResourceWatcher { + class ResolveState implements XdsDependencyManager.XdsConfigWatcher { private final ConfigOrError emptyServiceConfig = serviceConfigParser.parseServiceConfig(Collections.emptyMap()); - private final String ldsResourceName; + private final String authority; + private final XdsDependencyManager xdsDependencyManager; private boolean stopped; @Nullable private Set existingClusters; // clusters to which new requests can be routed - @Nullable - private RouteDiscoveryState routeDiscoveryState; + private StatusOr lastConfigOrStatus; - ResolveState(String ldsResourceName) { - this.ldsResourceName = ldsResourceName; + private ResolveState(String ldsResourceName) { + authority = overrideAuthority != null ? overrideAuthority : encodedServiceAuthority; + xdsDependencyManager = + new XdsDependencyManager(xdsClient, syncContext, authority, ldsResourceName, + nameResolverArgs); } - @Override - public void onChanged(final XdsListenerResource.LdsUpdate update) { + void start() { + xdsDependencyManager.start(this); + } + + void refresh() { + xdsDependencyManager.requestReresolution(); + } + + private void shutdown() { if (stopped) { return; } - logger.log(XdsLogLevel.INFO, "Receive LDS resource update: {0}", update); - HttpConnectionManager httpConnectionManager = update.httpConnectionManager(); - List virtualHosts = httpConnectionManager.virtualHosts(); - String rdsName = httpConnectionManager.rdsName(); - cleanUpRouteDiscoveryState(); - if (virtualHosts != null) { - updateRoutes(virtualHosts, httpConnectionManager.httpMaxStreamDurationNano(), - httpConnectionManager.httpFilterConfigs()); - } else { - routeDiscoveryState = new RouteDiscoveryState( - rdsName, httpConnectionManager.httpMaxStreamDurationNano(), - httpConnectionManager.httpFilterConfigs()); - logger.log(XdsLogLevel.INFO, "Start watching 
RDS resource {0}", rdsName); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), - rdsName, routeDiscoveryState, syncContext); - } + + stopped = true; + xdsDependencyManager.shutdown(); + updateActiveFilters(null); } @Override - public void onError(final Status error) { - if (stopped || receivedConfig) { + public void onUpdate(StatusOr updateOrStatus) { + if (stopped) { return; } - listener.onError(Status.UNAVAILABLE.withCause(error.getCause()).withDescription( - String.format("Unable to load LDS %s. xDS server returned: %s: %s", - ldsResourceName, error.getCode(), error.getDescription()))); - } + logger.log(XdsLogLevel.INFO, "Receive XDS resource update: {0}", updateOrStatus); - @Override - public void onResourceDoesNotExist(final String resourceName) { - if (stopped) { + lastConfigOrStatus = updateOrStatus; + if (!updateOrStatus.hasValue()) { + updateActiveFilters(null); + cleanUpRoutes(updateOrStatus.getStatus()); return; } - String error = "LDS resource does not exist: " + resourceName; - logger.log(XdsLogLevel.INFO, error); - cleanUpRouteDiscoveryState(); - cleanUpRoutes(error); - } - private void start() { - logger.log(XdsLogLevel.INFO, "Start watching LDS resource {0}", ldsResourceName); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(), - ldsResourceName, this, syncContext); - } + // Process Route + XdsConfig update = updateOrStatus.getValue(); + HttpConnectionManager httpConnectionManager = update.getListener().httpConnectionManager(); + if (httpConnectionManager == null) { + logger.log(XdsLogLevel.INFO, "API Listener: httpConnectionManager does not exist."); + updateActiveFilters(null); + cleanUpRoutes(updateOrStatus.getStatus()); + return; + } - private void stop() { - logger.log(XdsLogLevel.INFO, "Stop watching LDS resource {0}", ldsResourceName); - stopped = true; - cleanUpRouteDiscoveryState(); - xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(), ldsResourceName, this); + VirtualHost virtualHost = 
update.getVirtualHost(); + ImmutableList filterConfigs = httpConnectionManager.httpFilterConfigs(); + long streamDurationNano = httpConnectionManager.httpMaxStreamDurationNano(); + + updateActiveFilters(filterConfigs); + updateRoutes(update, virtualHost, streamDurationNano, filterConfigs); } // called in syncContext - private void updateRoutes(List virtualHosts, long httpMaxStreamDurationNano, - @Nullable List filterConfigs) { - String authority = overrideAuthority != null ? overrideAuthority : encodedServiceAuthority; - VirtualHost virtualHost = RoutingUtils.findVirtualHostForHostName(virtualHosts, authority); - if (virtualHost == null) { - String error = "Failed to find virtual host matching hostname: " + authority; - logger.log(XdsLogLevel.WARNING, error); - cleanUpRoutes(error); - return; + private void updateActiveFilters(@Nullable List filterConfigs) { + if (filterConfigs == null) { + filterConfigs = ImmutableList.of(); + } + Set filtersToShutdown = new HashSet<>(activeFilters.keySet()); + for (NamedFilterConfig namedFilter : filterConfigs) { + String typeUrl = namedFilter.filterConfig.typeUrl(); + String filterKey = namedFilter.filterStateKey(); + + Filter.Provider provider = filterRegistry.get(typeUrl); + checkNotNull(provider, "provider %s", typeUrl); + Filter filter = activeFilters.computeIfAbsent( + filterKey, k -> provider.newInstance(namedFilter.name)); + checkNotNull(filter, "filter %s", filterKey); + filtersToShutdown.remove(filterKey); + } + + // Shutdown filters not present in current HCM. 
+ for (String filterKey : filtersToShutdown) { + Filter filterToShutdown = activeFilters.remove(filterKey); + checkNotNull(filterToShutdown, "filterToShutdown %s", filterKey); + filterToShutdown.close(); } + } + private void updateRoutes( + XdsConfig xdsConfig, + @Nullable VirtualHost virtualHost, + long httpMaxStreamDurationNano, + @Nullable List filterConfigs) { List routes = virtualHost.routes(); + ImmutableList.Builder routesData = ImmutableList.builder(); // Populate all clusters to which requests can be routed to through the virtual host. Set clusters = new HashSet<>(); @@ -719,26 +772,36 @@ private void updateRoutes(List virtualHosts, long httpMaxStreamDura for (Route route : routes) { RouteAction action = route.routeAction(); String prefixedName; - if (action != null) { - if (action.cluster() != null) { - prefixedName = prefixedClusterName(action.cluster()); + if (action == null) { + routesData.add(new RouteData(route.routeMatch(), null, ImmutableList.of())); + } else if (action.cluster() != null) { + prefixedName = prefixedClusterName(action.cluster()); + clusters.add(prefixedName); + clusterNameMap.put(prefixedName, action.cluster()); + ClientInterceptor filters = createFilters(filterConfigs, virtualHost, route, null); + routesData.add(new RouteData(route.routeMatch(), route.routeAction(), filters)); + } else if (action.weightedClusters() != null) { + ImmutableList.Builder filterList = ImmutableList.builder(); + for (ClusterWeight weightedCluster : action.weightedClusters()) { + prefixedName = prefixedClusterName(weightedCluster.name()); clusters.add(prefixedName); - clusterNameMap.put(prefixedName, action.cluster()); - } else if (action.weightedClusters() != null) { - for (ClusterWeight weighedCluster : action.weightedClusters()) { - prefixedName = prefixedClusterName(weighedCluster.name()); - clusters.add(prefixedName); - clusterNameMap.put(prefixedName, weighedCluster.name()); - } - } else if (action.namedClusterSpecifierPluginConfig() != null) { - 
PluginConfig pluginConfig = action.namedClusterSpecifierPluginConfig().config(); - if (pluginConfig instanceof RlsPluginConfig) { - prefixedName = prefixedClusterSpecifierPluginName( - action.namedClusterSpecifierPluginConfig().name()); - clusters.add(prefixedName); - rlsPluginConfigMap.put(prefixedName, (RlsPluginConfig) pluginConfig); - } + clusterNameMap.put(prefixedName, weightedCluster.name()); + filterList.add(createFilters(filterConfigs, virtualHost, route, weightedCluster)); + } + routesData.add( + new RouteData(route.routeMatch(), route.routeAction(), filterList.build())); + } else if (action.namedClusterSpecifierPluginConfig() != null) { + PluginConfig pluginConfig = action.namedClusterSpecifierPluginConfig().config(); + if (pluginConfig instanceof RlsPluginConfig) { + prefixedName = prefixedClusterSpecifierPluginName( + action.namedClusterSpecifierPluginConfig().name()); + clusters.add(prefixedName); + rlsPluginConfigMap.put(prefixedName, (RlsPluginConfig) pluginConfig); } + ClientInterceptor filters = createFilters(filterConfigs, virtualHost, route, null); + routesData.add(new RouteData(route.routeMatch(), route.routeAction(), filters)); + } else { + // Discard route } } @@ -755,9 +818,13 @@ private void updateRoutes(List virtualHosts, long httpMaxStreamDura clusterRefs.get(cluster).refCount.incrementAndGet(); } else { if (clusterNameMap.containsKey(cluster)) { + assert cluster.startsWith("cluster:"); + XdsConfig.Subscription subscription = + xdsDependencyManager.subscribeToCluster(cluster.substring("cluster:".length())); clusterRefs.put( cluster, - ClusterRefState.forCluster(new AtomicInteger(1), clusterNameMap.get(cluster))); + ClusterRefState.forCluster( + new AtomicInteger(1), clusterNameMap.get(cluster), subscription)); } if (rlsPluginConfigMap.containsKey(cluster)) { clusterRefs.put( @@ -778,108 +845,86 @@ private void updateRoutes(List virtualHosts, long httpMaxStreamDura } } // Update service config to include newly added clusters. 
- if (shouldUpdateResult) { - updateResolutionResult(); + if (shouldUpdateResult && routingConfig != null) { + updateResolutionResult(xdsConfig); + shouldUpdateResult = false; + } else { + // Need to update at least once + shouldUpdateResult = true; } // Make newly added clusters selectable by config selector and deleted clusters no longer // selectable. - routingConfig = - new RoutingConfig( - httpMaxStreamDurationNano, routes, filterConfigs, - virtualHost.filterConfigOverrides()); - shouldUpdateResult = false; + routingConfig = new RoutingConfig(xdsConfig, httpMaxStreamDurationNano, routesData.build()); for (String cluster : deletedClusters) { int count = clusterRefs.get(cluster).refCount.decrementAndGet(); if (count == 0) { - clusterRefs.remove(cluster); + clusterRefs.remove(cluster).close(); shouldUpdateResult = true; } } if (shouldUpdateResult) { - updateResolutionResult(); + updateResolutionResult(xdsConfig); + } + } + + private ClientInterceptor createFilters( + @Nullable List filterConfigs, + VirtualHost virtualHost, + Route route, + @Nullable ClusterWeight weightedCluster) { + if (filterConfigs == null) { + return new PassthroughClientInterceptor(); + } + + Map selectedOverrideConfigs = + new HashMap<>(virtualHost.filterConfigOverrides()); + selectedOverrideConfigs.putAll(route.filterConfigOverrides()); + if (weightedCluster != null) { + selectedOverrideConfigs.putAll(weightedCluster.filterConfigOverrides()); + } + + ImmutableList.Builder filterInterceptors = ImmutableList.builder(); + for (NamedFilterConfig namedFilter : filterConfigs) { + String name = namedFilter.name; + FilterConfig config = namedFilter.filterConfig; + FilterConfig overrideConfig = selectedOverrideConfigs.get(name); + String filterKey = namedFilter.filterStateKey(); + + Filter filter = activeFilters.get(filterKey); + checkNotNull(filter, "activeFilters.get(%s)", filterKey); + ClientInterceptor interceptor = + filter.buildClientInterceptor(config, overrideConfig, scheduler); + + if 
(interceptor != null) { + filterInterceptors.add(interceptor); + } } + + // Combine interceptors produced by different filters into a single one that executes + // them sequentially. The order is preserved. + return combineInterceptors(filterInterceptors.build()); } - private void cleanUpRoutes(String error) { + private void cleanUpRoutes(Status error) { + routingConfig = new RoutingConfig(error); if (existingClusters != null) { for (String cluster : existingClusters) { int count = clusterRefs.get(cluster).refCount.decrementAndGet(); if (count == 0) { - clusterRefs.remove(cluster); + clusterRefs.remove(cluster).close(); } } existingClusters = null; } - routingConfig = RoutingConfig.empty; + // Without addresses the default LB (normally pick_first) should become TRANSIENT_FAILURE, and - // the config selector handles the error message itself. Once the LB API allows providing - // failure information for addresses yet still providing a service config, the config seector - // could be avoided. - listener.onResult(ResolutionResult.newBuilder() + // the config selector handles the error message itself. + listener.onResult2(ResolutionResult.newBuilder() .setAttributes(Attributes.newBuilder() - .set(InternalConfigSelector.KEY, - new FailingConfigSelector(Status.UNAVAILABLE.withDescription(error))) + .set(InternalConfigSelector.KEY, configSelector) .build()) .setServiceConfig(emptyServiceConfig) .build()); - receivedConfig = true; - } - - private void cleanUpRouteDiscoveryState() { - if (routeDiscoveryState != null) { - String rdsName = routeDiscoveryState.resourceName; - logger.log(XdsLogLevel.INFO, "Stop watching RDS resource {0}", rdsName); - xdsClient.cancelXdsResourceWatch(XdsRouteConfigureResource.getInstance(), rdsName, - routeDiscoveryState); - routeDiscoveryState = null; - } - } - - /** - * Discovery state for RouteConfiguration resource. One instance for each Listener resource - * update. 
- */ - private class RouteDiscoveryState implements ResourceWatcher { - private final String resourceName; - private final long httpMaxStreamDurationNano; - @Nullable - private final List filterConfigs; - - private RouteDiscoveryState(String resourceName, long httpMaxStreamDurationNano, - @Nullable List filterConfigs) { - this.resourceName = resourceName; - this.httpMaxStreamDurationNano = httpMaxStreamDurationNano; - this.filterConfigs = filterConfigs; - } - - @Override - public void onChanged(final RdsUpdate update) { - if (RouteDiscoveryState.this != routeDiscoveryState) { - return; - } - logger.log(XdsLogLevel.INFO, "Received RDS resource update: {0}", update); - updateRoutes(update.virtualHosts, httpMaxStreamDurationNano, filterConfigs); - } - - @Override - public void onError(final Status error) { - if (RouteDiscoveryState.this != routeDiscoveryState || receivedConfig) { - return; - } - listener.onError(Status.UNAVAILABLE.withCause(error.getCause()).withDescription( - String.format("Unable to load RDS %s. xDS server returned: %s: %s", - resourceName, error.getCode(), error.getDescription()))); - } - - @Override - public void onResourceDoesNotExist(final String resourceName) { - if (RouteDiscoveryState.this != routeDiscoveryState) { - return; - } - String error = "RDS resource does not exist: " + resourceName; - logger.log(XdsLogLevel.INFO, error); - cleanUpRoutes(error); - } } } @@ -887,23 +932,62 @@ public void onResourceDoesNotExist(final String resourceName) { * VirtualHost-level configuration for request routing. */ private static class RoutingConfig { - private final long fallbackTimeoutNano; - final List routes; - // Null if HttpFilter is not supported. 
- @Nullable final List filterChain; - final Map virtualHostOverrideConfig; - - private static RoutingConfig empty = new RoutingConfig( - 0, Collections.emptyList(), null, Collections.emptyMap()); + final XdsConfig xdsConfig; + final long fallbackTimeoutNano; + final ImmutableList routes; + final Status errorStatus; private RoutingConfig( - long fallbackTimeoutNano, List routes, @Nullable List filterChain, - Map virtualHostOverrideConfig) { + XdsConfig xdsConfig, long fallbackTimeoutNano, ImmutableList routes) { + this.xdsConfig = checkNotNull(xdsConfig, "xdsConfig"); this.fallbackTimeoutNano = fallbackTimeoutNano; - this.routes = routes; - checkArgument(filterChain == null || !filterChain.isEmpty(), "filterChain is empty"); - this.filterChain = filterChain == null ? null : Collections.unmodifiableList(filterChain); - this.virtualHostOverrideConfig = Collections.unmodifiableMap(virtualHostOverrideConfig); + this.routes = checkNotNull(routes, "routes"); + this.errorStatus = null; + } + + private RoutingConfig(Status errorStatus) { + this.xdsConfig = null; + this.fallbackTimeoutNano = 0; + this.routes = null; + this.errorStatus = checkNotNull(errorStatus, "errorStatus"); + checkArgument(!errorStatus.isOk(), "errorStatus should not be okay"); + } + } + + static final class RouteData { + final RouteMatch routeMatch; + /** null implies non-forwarding action. */ + @Nullable + final RouteAction routeAction; + /** + * Only one of these interceptors should be used per-RPC. There are only multiple values in the + * list for weighted clusters, in which case the order of the list mirrors the weighted + * clusters. 
+ */ + final ImmutableList filterChoices; + + RouteData(RouteMatch routeMatch, @Nullable RouteAction routeAction, ClientInterceptor filter) { + this(routeMatch, routeAction, ImmutableList.of(filter)); + } + + RouteData( + RouteMatch routeMatch, + @Nullable RouteAction routeAction, + ImmutableList filterChoices) { + this.routeMatch = checkNotNull(routeMatch, "routeMatch"); + checkArgument( + routeAction == null || !filterChoices.isEmpty(), + "filter may be empty only for non-forwarding action"); + this.routeAction = routeAction; + if (routeAction != null && routeAction.weightedClusters() != null) { + checkArgument( + routeAction.weightedClusters().size() == filterChoices.size(), + "filter choices must match size of weighted clusters"); + } + for (ClientInterceptor filter : filterChoices) { + checkNotNull(filter, "entry in filterChoices is null"); + } + this.filterChoices = checkNotNull(filterChoices, "filterChoices"); } } @@ -913,15 +997,18 @@ private static class ClusterRefState { final String traditionalCluster; @Nullable final RlsPluginConfig rlsPluginConfig; + @Nullable + final XdsConfig.Subscription subscription; private ClusterRefState( AtomicInteger refCount, @Nullable String traditionalCluster, - @Nullable RlsPluginConfig rlsPluginConfig) { + @Nullable RlsPluginConfig rlsPluginConfig, @Nullable XdsConfig.Subscription subscription) { this.refCount = refCount; checkArgument(traditionalCluster == null ^ rlsPluginConfig == null, "There must be exactly one non-null value in traditionalCluster and pluginConfig"); this.traditionalCluster = traditionalCluster; this.rlsPluginConfig = rlsPluginConfig; + this.subscription = subscription; } private Map toLbPolicy() { @@ -934,19 +1021,97 @@ private ClusterRefState( .put("routeLookupConfig", rlsPluginConfig.config()) .put( "childPolicy", - ImmutableList.of(ImmutableMap.of(XdsLbPolicies.CDS_POLICY_NAME, ImmutableMap.of()))) + ImmutableList.of(ImmutableMap.of(XdsLbPolicies.CDS_POLICY_NAME, ImmutableMap.of( + "is_dynamic", 
true)))) .put("childPolicyConfigTargetFieldName", "cluster") .buildOrThrow(); return ImmutableMap.of("rls_experimental", rlsConfig); } } - static ClusterRefState forCluster(AtomicInteger refCount, String name) { - return new ClusterRefState(refCount, name, null); + private void close() { + if (subscription != null) { + subscription.close(); + } + } + + static ClusterRefState forCluster( + AtomicInteger refCount, String name, XdsConfig.Subscription subscription) { + return new ClusterRefState(refCount, name, null, checkNotNull(subscription, "subscription")); + } + + static ClusterRefState forRlsPlugin( + AtomicInteger refCount, + RlsPluginConfig rlsPluginConfig) { + return new ClusterRefState(refCount, null, rlsPluginConfig, null); + } + } + + /** An ObjectPool, except it can throw an exception. */ + private interface XdsClientPool { + XdsClient getObject() throws XdsInitializationException; + + XdsClient returnObject(XdsClient xdsClient); + } + + private static final class BootstrappingXdsClientPool implements XdsClientPool { + private final XdsClientPoolFactory xdsClientPoolFactory; + private final String target; + private final @Nullable Map bootstrapOverride; + private final MetricRecorder metricRecorder; + private ObjectPool xdsClientPool; + + BootstrappingXdsClientPool( + XdsClientPoolFactory xdsClientPoolFactory, + String target, + @Nullable Map bootstrapOverride, + MetricRecorder metricRecorder) { + this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + this.target = checkNotNull(target, "target"); + this.bootstrapOverride = bootstrapOverride; + this.metricRecorder = checkNotNull(metricRecorder, "metricRecorder"); } - static ClusterRefState forRlsPlugin(AtomicInteger refCount, RlsPluginConfig rlsPluginConfig) { - return new ClusterRefState(refCount, null, rlsPluginConfig); + @Override + public XdsClient getObject() throws XdsInitializationException { + if (xdsClientPool == null) { + BootstrapInfo bootstrapInfo; + if 
(bootstrapOverride == null) { + bootstrapInfo = GrpcBootstrapperImpl.defaultBootstrap(); + } else { + bootstrapInfo = new GrpcBootstrapperImpl().bootstrap(bootstrapOverride); + } + this.xdsClientPool = + xdsClientPoolFactory.getOrCreate(target, bootstrapInfo, metricRecorder); + } + return xdsClientPool.getObject(); + } + + @Override + public XdsClient returnObject(XdsClient xdsClient) { + return xdsClientPool.returnObject(xdsClient); + } + } + + private static final class SupplierXdsClientPool implements XdsClientPool { + private final Supplier xdsClientSupplier; + + SupplierXdsClientPool(Supplier xdsClientSupplier) { + this.xdsClientSupplier = checkNotNull(xdsClientSupplier, "xdsClientSupplier"); + } + + @Override + public XdsClient getObject() throws XdsInitializationException { + XdsClient xdsClient = xdsClientSupplier.get(); + if (xdsClient == null) { + throw new XdsInitializationException("Caller failed to initialize XDS_CLIENT_SUPPLIER"); + } + return xdsClient; + } + + @Override + public XdsClient returnObject(XdsClient xdsClient) { + return null; } } } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java index 8d0e59eaa91..51b1ff49bf0 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java @@ -22,6 +22,8 @@ import io.grpc.Internal; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; +import io.grpc.Uri; +import io.grpc.xds.client.XdsClient; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; @@ -29,6 +31,7 @@ import java.util.Collections; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; import javax.annotation.Nullable; /** @@ -43,6 +46,13 @@ */ @Internal public final class XdsNameResolverProvider extends NameResolverProvider { + /** + * If provided, the suppler must return non-null when 
lb.start() is called (which implies not + * throwing), and the XdsClient must remain alive until lb.shutdown() returns. It may only be + * called from the synchronization context. + */ + public static final Args.Key> XDS_CLIENT_SUPPLIER = + Args.Key.create("io.grpc.xds.XdsNameResolverProvider.XDS_CLIENT_SUPPLIER"); private static final String SCHEME = "xds"; private final String scheme; @@ -77,15 +87,43 @@ public XdsNameResolver newNameResolver(URI targetUri, Args args) { targetPath, targetUri); String name = targetPath.substring(1); - return new XdsNameResolver( - targetUri, name, args.getOverrideAuthority(), - args.getServiceConfigParser(), args.getSynchronizationContext(), - args.getScheduledExecutorService(), - bootstrapOverride); + // TODO(jdcormie): java.net.URI#getAuthority incorrectly returns null for both xds:///service + // and xds:/service. This doesn't matter for now since XdsNameResolver treats them the same + // anyway and all this code will go away once newNameResolver(io.grpc.Uri) launches. 
+ String targetAuthority = targetUri.getAuthority(); + return newNameResolver(targetUri.toString(), targetAuthority, name, args); + } + return null; + } + + @Override + public XdsNameResolver newNameResolver(Uri targetUri, Args args) { + if (scheme.equals(targetUri.getScheme())) { + Preconditions.checkArgument( + targetUri.isPathAbsolute(), + "the path component of the target (%s) must start with '/'", + targetUri); + return newNameResolver( + targetUri.toString(), targetUri.getAuthority(), targetUri.getPath().substring(1), args); } return null; } + private XdsNameResolver newNameResolver( + String targetUri, String targetAuthority, String name, Args args) { + return new XdsNameResolver( + targetUri.toString(), + targetAuthority, + name, + args.getOverrideAuthority(), + args.getServiceConfigParser(), + args.getSynchronizationContext(), + args.getScheduledExecutorService(), + bootstrapOverride, + args.getMetricRecorder(), + args); + } + @Override public String getDefaultScheme() { return scheme; diff --git a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java index 0a3d1406dac..24ec0659b42 100644 --- a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java @@ -68,6 +68,9 @@ import javax.annotation.Nullable; class XdsRouteConfigureResource extends XdsResourceType { + + private static final boolean isXdsAuthorityRewriteEnabled = GrpcUtil.getFlag( + "GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", true); @VisibleForTesting static boolean enableRouteLookup = GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_RLS_LB", true); @@ -75,6 +78,8 @@ class XdsRouteConfigureResource extends XdsResourceType { "type.googleapis.com/envoy.config.route.v3.RouteConfiguration"; private static final String TYPE_URL_FILTER_CONFIG = "type.googleapis.com/envoy.config.route.v3.FilterConfig"; + @VisibleForTesting + static final String HASH_POLICY_FILTER_STATE_KEY = 
"io.grpc.channel_id"; // TODO(zdapeng): need to discuss how to handle unsupported values. private static final Set SUPPORTED_RETRYABLE_CODES = Collections.unmodifiableSet(EnumSet.of( @@ -128,17 +133,17 @@ protected RdsUpdate doParse(XdsResourceType.Args args, Message unpackedMessage) throw new ResourceInvalidException("Invalid message type: " + unpackedMessage.getClass()); } return processRouteConfiguration( - (RouteConfiguration) unpackedMessage, FilterRegistry.getDefaultRegistry()); + (RouteConfiguration) unpackedMessage, FilterRegistry.getDefaultRegistry(), args); } private static RdsUpdate processRouteConfiguration( - RouteConfiguration routeConfig, FilterRegistry filterRegistry) + RouteConfiguration routeConfig, FilterRegistry filterRegistry, XdsResourceType.Args args) throws ResourceInvalidException { - return new RdsUpdate(extractVirtualHosts(routeConfig, filterRegistry)); + return new RdsUpdate(extractVirtualHosts(routeConfig, filterRegistry, args)); } static List extractVirtualHosts( - RouteConfiguration routeConfig, FilterRegistry filterRegistry) + RouteConfiguration routeConfig, FilterRegistry filterRegistry, XdsResourceType.Args args) throws ResourceInvalidException { Map pluginConfigMap = new HashMap<>(); ImmutableSet.Builder optionalPlugins = ImmutableSet.builder(); @@ -164,7 +169,7 @@ static List extractVirtualHosts( : routeConfig.getVirtualHostsList()) { StructOrError virtualHost = parseVirtualHost(virtualHostProto, filterRegistry, pluginConfigMap, - optionalPlugins.build()); + optionalPlugins.build(), args); if (virtualHost.getErrorDetail() != null) { throw new ResourceInvalidException( "RouteConfiguration contains invalid virtual host: " + virtualHost.getErrorDetail()); @@ -177,12 +182,12 @@ static List extractVirtualHosts( private static StructOrError parseVirtualHost( io.envoyproxy.envoy.config.route.v3.VirtualHost proto, FilterRegistry filterRegistry, Map pluginConfigMap, - Set optionalPlugins) { + Set optionalPlugins, XdsResourceType.Args 
args) { String name = proto.getName(); List routes = new ArrayList<>(proto.getRoutesCount()); for (io.envoyproxy.envoy.config.route.v3.Route routeProto : proto.getRoutesList()) { StructOrError route = parseRoute( - routeProto, filterRegistry, pluginConfigMap, optionalPlugins); + routeProto, filterRegistry, pluginConfigMap, optionalPlugins, args); if (route == null) { continue; } @@ -240,8 +245,8 @@ static StructOrError> parseOverrideFilterConfigs( return StructOrError.fromError( "FilterConfig [" + name + "] contains invalid proto: " + e); } - Filter filter = filterRegistry.get(typeUrl); - if (filter == null) { + Filter.Provider provider = filterRegistry.get(typeUrl); + if (provider == null) { if (isOptional) { continue; } @@ -249,7 +254,7 @@ static StructOrError> parseOverrideFilterConfigs( "HttpFilter [" + name + "](" + typeUrl + ") is required but unsupported"); } ConfigOrError filterConfig = - filter.parseFilterConfigOverride(rawConfig); + provider.parseFilterConfigOverride(rawConfig); if (filterConfig.errorDetail != null) { return StructOrError.fromError( "Invalid filter config for HttpFilter [" + name + "]: " + filterConfig.errorDetail); @@ -264,7 +269,7 @@ static StructOrError> parseOverrideFilterConfigs( static StructOrError parseRoute( io.envoyproxy.envoy.config.route.v3.Route proto, FilterRegistry filterRegistry, Map pluginConfigMap, - Set optionalPlugins) { + Set optionalPlugins, XdsResourceType.Args args) { StructOrError routeMatch = parseRouteMatch(proto.getMatch()); if (routeMatch == null) { return null; @@ -288,7 +293,7 @@ static StructOrError parseRoute( case ROUTE: StructOrError routeAction = parseRouteAction(proto.getRoute(), filterRegistry, pluginConfigMap, - optionalPlugins); + optionalPlugins, args); if (routeAction == null) { return null; } @@ -414,7 +419,7 @@ static StructOrError parseHeaderMatcher( static StructOrError parseRouteAction( io.envoyproxy.envoy.config.route.v3.RouteAction proto, FilterRegistry filterRegistry, Map pluginConfigMap, 
- Set optionalPlugins) { + Set optionalPlugins, XdsResourceType.Args args) { Long timeoutNano = null; if (proto.hasMaxStreamDuration()) { io.envoyproxy.envoy.config.route.v3.RouteAction.MaxStreamDuration maxStreamDuration @@ -446,8 +451,7 @@ static StructOrError parseRouteAction( config.getHeader(); Pattern regEx = null; String regExSubstitute = null; - if (headerCfg.hasRegexRewrite() && headerCfg.getRegexRewrite().hasPattern() - && headerCfg.getRegexRewrite().getPattern().hasGoogleRe2()) { + if (headerCfg.hasRegexRewrite() && headerCfg.getRegexRewrite().hasPattern()) { regEx = Pattern.compile(headerCfg.getRegexRewrite().getPattern().getRegex()); regExSubstitute = headerCfg.getRegexRewrite().getSubstitution(); } @@ -470,7 +474,9 @@ static StructOrError parseRouteAction( switch (proto.getClusterSpecifierCase()) { case CLUSTER: return StructOrError.fromStruct(RouteAction.forCluster( - proto.getCluster(), hashPolicies, timeoutNano, retryPolicy)); + proto.getCluster(), hashPolicies, timeoutNano, retryPolicy, + isXdsAuthorityRewriteEnabled && args.getServerInfo().isTrustedXdsServer() + && proto.getAutoHostRewrite().getValue())); case CLUSTER_HEADER: return null; case WEIGHTED_CLUSTERS: @@ -489,8 +495,9 @@ static StructOrError parseRouteAction( return StructOrError.fromError("RouteAction contains invalid ClusterWeight: " + clusterWeightOrError.getErrorDetail()); } - clusterWeightSum += clusterWeight.getWeight().getValue(); - weightedClusters.add(clusterWeightOrError.getStruct()); + ClusterWeight parsedWeight = clusterWeightOrError.getStruct(); + clusterWeightSum += parsedWeight.weight(); + weightedClusters.add(parsedWeight); } if (clusterWeightSum <= 0) { return StructOrError.fromError("Sum of cluster weights should be above 0."); @@ -502,7 +509,9 @@ static StructOrError parseRouteAction( UnsignedInteger.MAX_VALUE.longValue(), clusterWeightSum)); } return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forWeightedClusters( - weightedClusters, hashPolicies, 
timeoutNano, retryPolicy)); + weightedClusters, hashPolicies, timeoutNano, retryPolicy, + isXdsAuthorityRewriteEnabled && args.getServerInfo().isTrustedXdsServer() + && proto.getAutoHostRewrite().getValue())); case CLUSTER_SPECIFIER_PLUGIN: if (enableRouteLookup) { String pluginName = proto.getClusterSpecifierPlugin(); @@ -517,7 +526,9 @@ static StructOrError parseRouteAction( } NamedPluginConfig namedPluginConfig = NamedPluginConfig.create(pluginName, pluginConfig); return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forClusterSpecifierPlugin( - namedPluginConfig, hashPolicies, timeoutNano, retryPolicy)); + namedPluginConfig, hashPolicies, timeoutNano, retryPolicy, + isXdsAuthorityRewriteEnabled && args.getServerInfo().isTrustedXdsServer() + && proto.getAutoHostRewrite().getValue())); } else { return null; } @@ -597,7 +608,9 @@ static StructOrError parseClusterWe + overrideConfigs.getErrorDetail()); } return StructOrError.fromStruct(VirtualHost.Route.RouteAction.ClusterWeight.create( - proto.getName(), proto.getWeight().getValue(), overrideConfigs.getStruct())); + proto.getName(), + Integer.toUnsignedLong(proto.getWeight().getValue()), + overrideConfigs.getStruct())); } @Nullable // null if the plugin is not supported, but it's marked as optional. 
diff --git a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java index b75d5755f6e..4a4fb71aa84 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java @@ -19,8 +19,8 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import static io.grpc.xds.InternalXdsAttributes.ATTR_DRAIN_GRACE_NANOS; -import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; +import static io.grpc.xds.XdsAttributes.ATTR_DRAIN_GRACE_NANOS; +import static io.grpc.xds.XdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; import com.google.common.annotations.VisibleForTesting; import com.google.errorprone.annotations.DoNotCall; @@ -55,6 +55,7 @@ public final class XdsServerBuilder extends ForwardingServerBuilder bootstrapOverride; private long drainGraceTime = 10; private TimeUnit drainGraceTimeUnit = TimeUnit.MINUTES; @@ -127,7 +128,7 @@ public Server build() { } InternalNettyServerBuilder.eagAttributes(delegate, builder.build()); return new XdsServerWrapper("0.0.0.0:" + port, delegate, xdsServingStatusListener, - filterChainSelectorManager, xdsClientPoolFactory, filterRegistry); + filterChainSelectorManager, xdsClientPoolFactory, bootstrapOverride, filterRegistry); } @VisibleForTesting @@ -140,11 +141,10 @@ XdsServerBuilder xdsClientPoolFactory(XdsClientPoolFactory xdsClientPoolFactory) * Allows providing bootstrap override, useful for testing. 
*/ public XdsServerBuilder overrideBootstrapForTest(Map bootstrapOverride) { - checkNotNull(bootstrapOverride, "bootstrapOverride"); + this.bootstrapOverride = checkNotNull(bootstrapOverride, "bootstrapOverride"); if (this.xdsClientPoolFactory == SharedXdsClientPoolProvider.getDefaultProvider()) { this.xdsClientPoolFactory = new SharedXdsClientPoolProvider(); } - this.xdsClientPoolFactory.setBootstrapOverride(bootstrapOverride); return this; } diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index dfb7c4fb7db..5529f96c7a2 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -24,11 +24,15 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.net.HostAndPort; +import com.google.common.net.InetAddresses; import com.google.common.util.concurrent.SettableFuture; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.grpc.Attributes; import io.grpc.InternalServerInterceptors; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MetricRecorder; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.ServerCall; @@ -38,6 +42,7 @@ import io.grpc.ServerServiceDefinition; import io.grpc.Status; import io.grpc.StatusException; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.GrpcUtil; @@ -46,20 +51,20 @@ import io.grpc.xds.EnvoyServerProtoData.FilterChain; import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; -import io.grpc.xds.Filter.ServerInterceptorBuilder; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import 
io.grpc.xds.VirtualHost.Route; import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsClient.ResourceWatcher; import io.grpc.xds.internal.security.SslContextProviderSupplier; import java.io.IOException; +import java.net.InetAddress; import java.net.SocketAddress; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -100,6 +105,7 @@ public void uncaughtException(Thread t, Throwable e) { private final FilterRegistry filterRegistry; private final ThreadSafeRandom random = ThreadSafeRandomImpl.instance; private final XdsClientPoolFactory xdsClientPoolFactory; + private final @Nullable Map bootstrapOverride; private final XdsServingStatusListener listener; private final FilterChainSelectorManager filterChainSelectorManager; private final AtomicBoolean started = new AtomicBoolean(false); @@ -114,15 +120,31 @@ public void uncaughtException(Thread t, Throwable e) { private DiscoveryState discoveryState; private volatile Server delegate; + // Must be accessed in syncContext. + // Filter instances are unique per Server, per FilterChain, and per filter's name+typeUrl. + // FilterChain.name -> filter_instance>. + private final HashMap> activeFilters = new HashMap<>(); + // Default filter chain Filter instances are unique per Server, and per filter's name+typeUrl. + // NamedFilterConfig.filterStateKey -> filter_instance. 
+ private final HashMap activeFiltersDefaultChain = new HashMap<>(); + XdsServerWrapper( String listenerAddress, ServerBuilder delegateBuilder, XdsServingStatusListener listener, FilterChainSelectorManager filterChainSelectorManager, XdsClientPoolFactory xdsClientPoolFactory, + @Nullable Map bootstrapOverride, FilterRegistry filterRegistry) { - this(listenerAddress, delegateBuilder, listener, filterChainSelectorManager, - xdsClientPoolFactory, filterRegistry, SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE)); + this( + listenerAddress, + delegateBuilder, + listener, + filterChainSelectorManager, + xdsClientPoolFactory, + bootstrapOverride, + filterRegistry, + SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE)); sharedTimeService = true; } @@ -133,6 +155,7 @@ public void uncaughtException(Thread t, Throwable e) { XdsServingStatusListener listener, FilterChainSelectorManager filterChainSelectorManager, XdsClientPoolFactory xdsClientPoolFactory, + @Nullable Map bootstrapOverride, FilterRegistry filterRegistry, ScheduledExecutorService timeService) { this.listenerAddress = checkNotNull(listenerAddress, "listenerAddress"); @@ -142,6 +165,7 @@ public void uncaughtException(Thread t, Throwable e) { this.filterChainSelectorManager = checkNotNull(filterChainSelectorManager, "filterChainSelectorManager"); this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + this.bootstrapOverride = bootstrapOverride; this.timeService = checkNotNull(timeService, "timeService"); this.filterRegistry = checkNotNull(filterRegistry,"filterRegistry"); this.delegate = delegateBuilder.build(); @@ -171,7 +195,14 @@ public void run() { private void internalStart() { try { - xdsClientPool = xdsClientPoolFactory.getOrCreate(""); + BootstrapInfo bootstrapInfo; + if (bootstrapOverride == null) { + bootstrapInfo = GrpcBootstrapperImpl.defaultBootstrap(); + } else { + bootstrapInfo = new GrpcBootstrapperImpl().bootstrap(bootstrapOverride); + } + xdsClientPool = 
xdsClientPoolFactory.getOrCreate( + "#server", bootstrapInfo, new MetricRecorder() {}); } catch (Exception e) { StatusException statusException = Status.UNAVAILABLE.withDescription( "Failed to initialize xDS").withCause(e).asException(); @@ -371,25 +402,55 @@ private DiscoveryState(String resourceName) { } @Override - public void onChanged(final LdsUpdate update) { + public void onResourceChanged(final StatusOr update) { if (stopped) { return; } - logger.log(Level.FINEST, "Received Lds update {0}", update); - checkNotNull(update.listener(), "update"); + + if (!update.hasValue()) { + Status status = update.getStatus(); + StatusException statusException = Status.UNAVAILABLE.withDescription( + String.format("Listener %s unavailable: %s", resourceName, status.getDescription())) + .withCause(status.asException()) + .asException(); + handleConfigNotFoundOrMismatch(statusException); + return; + } + + final LdsUpdate ldsUpdate = update.getValue(); + logger.log(Level.FINEST, "Received Lds update {0}", ldsUpdate); + if (ldsUpdate.listener() == null) { + handleConfigNotFoundOrMismatch( + Status.NOT_FOUND.withDescription("Listener is null in LdsUpdate").asException()); + return; + } + String ldsAddress = ldsUpdate.listener().address(); + if (ldsAddress == null || ldsUpdate.listener().protocol() != Protocol.TCP + || !ipAddressesMatch(ldsAddress)) { + handleConfigNotFoundOrMismatch( + Status.UNKNOWN.withDescription( + String.format( + "Listener address mismatch: expected %s, but got %s.", + listenerAddress, ldsAddress)).asException()); + return; + } + if (!pendingRds.isEmpty()) { // filter chain state has not yet been applied to filterChainSelectorManager and there - // are two sets of sslContextProviderSuppliers, so we release the old ones. 
releaseSuppliersInFlight(); pendingRds.clear(); } - filterChains = update.listener().filterChains(); - defaultFilterChain = update.listener().defaultFilterChain(); + + filterChains = ldsUpdate.listener().filterChains(); + defaultFilterChain = ldsUpdate.listener().defaultFilterChain(); + updateActiveFilters(); + List allFilterChains = filterChains; if (defaultFilterChain != null) { allFilterChains = new ArrayList<>(filterChains); allFilterChains.add(defaultFilterChain); } + Set allRds = new HashSet<>(); for (FilterChain filterChain : allFilterChains) { HttpConnectionManager hcm = filterChain.httpConnectionManager(); @@ -407,6 +468,7 @@ public void onChanged(final LdsUpdate update) { allRds.add(hcm.rdsName()); } } + for (Map.Entry entry: routeDiscoveryStates.entrySet()) { if (!allRds.contains(entry.getKey())) { xdsClient.cancelXdsResourceWatch(XdsRouteConfigureResource.getInstance(), @@ -420,31 +482,38 @@ public void onChanged(final LdsUpdate update) { } @Override - public void onResourceDoesNotExist(final String resourceName) { + public void onAmbientError(final Status error) { if (stopped) { return; } - StatusException statusException = Status.UNAVAILABLE.withDescription( - "Listener " + resourceName + " unavailable").asException(); - handleConfigNotFound(statusException); - } + String description = error.getDescription() == null ? 
"" : error.getDescription() + " "; + Status errorWithNodeId = error.withDescription( + description + "xDS node ID: " + xdsClient.getBootstrapInfo().node().getId()); + logger.log(Level.FINE, "Error from XdsClient", errorWithNodeId); - @Override - public void onError(final Status error) { - if (stopped) { - return; - } - logger.log(Level.FINE, "Error from XdsClient", error); if (!isServing) { - listener.onNotServing(error.asException()); + listener.onNotServing(errorWithNodeId.asException()); } } + private boolean ipAddressesMatch(String ldsAddress) { + HostAndPort ldsAddressHnP = HostAndPort.fromString(ldsAddress); + HostAndPort listenerAddressHnP = HostAndPort.fromString(listenerAddress); + if (!ldsAddressHnP.hasPort() || !listenerAddressHnP.hasPort() + || ldsAddressHnP.getPort() != listenerAddressHnP.getPort()) { + return false; + } + InetAddress listenerIp = InetAddresses.forString(listenerAddressHnP.getHost()); + InetAddress ldsIp = InetAddresses.forString(ldsAddressHnP.getHost()); + return listenerIp.equals(ldsIp); + } + private void shutdown() { stopped = true; cleanUpRouteDiscoveryStates(); logger.log(Level.FINE, "Stop watching LDS resource {0}", resourceName); xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(), resourceName, this); + shutdownActiveFilters(); List toRelease = getSuppliersInUse(); filterChainSelectorManager.updateSelector(FilterChainSelector.NO_FILTER_CHAIN); for (SslContextProviderSupplier s: toRelease) { @@ -454,81 +523,184 @@ private void shutdown() { } private void updateSelector() { - Map> filterChainRouting = new HashMap<>(); + // This is regenerated in generateRoutingConfig() calls below. savedRdsRoutingConfigRef.clear(); + + // Prepare server routing config map. 
+ ImmutableMap.Builder> routingConfigs = + ImmutableMap.builder(); for (FilterChain filterChain: filterChains) { - filterChainRouting.put(filterChain, generateRoutingConfig(filterChain)); + HashMap chainFilters = activeFilters.get(filterChain.name()); + routingConfigs.put(filterChain, generateRoutingConfig(filterChain, chainFilters)); } - FilterChainSelector selector = new FilterChainSelector( - Collections.unmodifiableMap(filterChainRouting), - defaultFilterChain == null ? null : defaultFilterChain.sslContextProviderSupplier(), - defaultFilterChain == null ? new AtomicReference() : - generateRoutingConfig(defaultFilterChain)); - List toRelease = getSuppliersInUse(); + + // Prepare the new selector. + FilterChainSelector selector; + if (defaultFilterChain != null) { + selector = new FilterChainSelector( + routingConfigs.build(), + defaultFilterChain.sslContextProviderSupplier(), + generateRoutingConfig(defaultFilterChain, activeFiltersDefaultChain)); + } else { + selector = new FilterChainSelector(routingConfigs.build()); + } + + // Prepare the list of current selector's resources to close later. + List oldSslSuppliers = getSuppliersInUse(); + + // Swap the selectors, initiate a graceful shutdown of the old one. logger.log(Level.FINEST, "Updating selector {0}", selector); filterChainSelectorManager.updateSelector(selector); - for (SslContextProviderSupplier e: toRelease) { - e.close(); + + // Release old resources. + for (SslContextProviderSupplier supplier: oldSslSuppliers) { + supplier.close(); } + + // Now that we have valid Transport Socket config, we can start/restart listening on a port. 
startDelegateServer(); } - private AtomicReference generateRoutingConfig(FilterChain filterChain) { + // called in syncContext + private void updateActiveFilters() { + Set removedChains = new HashSet<>(activeFilters.keySet()); + for (FilterChain filterChain: filterChains) { + removedChains.remove(filterChain.name()); + updateActiveFiltersForChain( + activeFilters.computeIfAbsent(filterChain.name(), k -> new HashMap<>()), + filterChain.httpConnectionManager().httpFilterConfigs()); + } + + // Shutdown all filters of chains missing from the LDS. + for (String chainToShutdown : removedChains) { + HashMap filtersToShutdown = activeFilters.get(chainToShutdown); + checkNotNull(filtersToShutdown, "filtersToShutdown of chain %s", chainToShutdown); + updateActiveFiltersForChain(filtersToShutdown, null); + activeFilters.remove(chainToShutdown); + } + + // Default chain. + ImmutableList defaultChainConfigs = null; + if (defaultFilterChain != null) { + defaultChainConfigs = defaultFilterChain.httpConnectionManager().httpFilterConfigs(); + } + updateActiveFiltersForChain(activeFiltersDefaultChain, defaultChainConfigs); + } + + // called in syncContext + private void shutdownActiveFilters() { + for (HashMap chainFilters : activeFilters.values()) { + checkNotNull(chainFilters, "chainFilters"); + updateActiveFiltersForChain(chainFilters, null); + } + activeFilters.clear(); + updateActiveFiltersForChain(activeFiltersDefaultChain, null); + } + + // called in syncContext + private void updateActiveFiltersForChain( + Map chainFilters, @Nullable List filterConfigs) { + if (filterConfigs == null) { + filterConfigs = ImmutableList.of(); + } + + Set filtersToShutdown = new HashSet<>(chainFilters.keySet()); + for (NamedFilterConfig namedFilter : filterConfigs) { + String typeUrl = namedFilter.filterConfig.typeUrl(); + String filterKey = namedFilter.filterStateKey(); + + Filter.Provider provider = filterRegistry.get(typeUrl); + checkNotNull(provider, "provider %s", typeUrl); + Filter filter 
= chainFilters.computeIfAbsent( + filterKey, k -> provider.newInstance(namedFilter.name)); + checkNotNull(filter, "filter %s", filterKey); + filtersToShutdown.remove(filterKey); + } + + // Shutdown filters not present in current HCM. + for (String filterKey : filtersToShutdown) { + Filter filterToShutdown = chainFilters.remove(filterKey); + checkNotNull(filterToShutdown, "filterToShutdown %s", filterKey); + filterToShutdown.close(); + } + } + + private AtomicReference generateRoutingConfig( + FilterChain filterChain, Map chainFilters) { HttpConnectionManager hcm = filterChain.httpConnectionManager(); - if (hcm.virtualHosts() != null) { - ImmutableMap interceptors = generatePerRouteInterceptors( - hcm.httpFilterConfigs(), hcm.virtualHosts()); - return new AtomicReference<>(ServerRoutingConfig.create(hcm.virtualHosts(),interceptors)); + ServerRoutingConfig routingConfig; + + // Inlined routes. + ImmutableList vhosts = hcm.virtualHosts(); + if (vhosts != null) { + routingConfig = ServerRoutingConfig.create(vhosts, + generatePerRouteInterceptors(hcm.httpFilterConfigs(), vhosts, chainFilters)); + return new AtomicReference<>(routingConfig); + } + + // Routes from RDS. 
+ RouteDiscoveryState rds = routeDiscoveryStates.get(hcm.rdsName()); + checkNotNull(rds, "rds"); + + ImmutableList savedVhosts = rds.savedVirtualHosts; + if (savedVhosts != null) { + routingConfig = ServerRoutingConfig.create(savedVhosts, + generatePerRouteInterceptors(hcm.httpFilterConfigs(), savedVhosts, chainFilters)); } else { - RouteDiscoveryState rds = routeDiscoveryStates.get(hcm.rdsName()); - checkNotNull(rds, "rds"); - AtomicReference serverRoutingConfigRef = new AtomicReference<>(); - if (rds.savedVirtualHosts != null) { - ImmutableMap interceptors = generatePerRouteInterceptors( - hcm.httpFilterConfigs(), rds.savedVirtualHosts); - ServerRoutingConfig serverRoutingConfig = - ServerRoutingConfig.create(rds.savedVirtualHosts, interceptors); - serverRoutingConfigRef.set(serverRoutingConfig); - } else { - serverRoutingConfigRef.set(ServerRoutingConfig.FAILING_ROUTING_CONFIG); - } - savedRdsRoutingConfigRef.put(filterChain, serverRoutingConfigRef); - return serverRoutingConfigRef; + routingConfig = ServerRoutingConfig.FAILING_ROUTING_CONFIG; } + AtomicReference routingConfigRef = new AtomicReference<>(routingConfig); + savedRdsRoutingConfigRef.put(filterChain, routingConfigRef); + return routingConfigRef; } private ImmutableMap generatePerRouteInterceptors( - List namedFilterConfigs, List virtualHosts) { + @Nullable List filterConfigs, + List virtualHosts, + Map chainFilters) { + syncContext.throwIfNotInThisSynchronizationContext(); + + checkNotNull(chainFilters, "chainFilters"); ImmutableMap.Builder perRouteInterceptors = new ImmutableMap.Builder<>(); + for (VirtualHost virtualHost : virtualHosts) { for (Route route : virtualHost.routes()) { - List filterInterceptors = new ArrayList<>(); - Map selectedOverrideConfigs = - new HashMap<>(virtualHost.filterConfigOverrides()); - selectedOverrideConfigs.putAll(route.filterConfigOverrides()); - if (namedFilterConfigs != null) { - for (NamedFilterConfig namedFilterConfig : namedFilterConfigs) { - FilterConfig 
filterConfig = namedFilterConfig.filterConfig; - Filter filter = filterRegistry.get(filterConfig.typeUrl()); - if (filter instanceof ServerInterceptorBuilder) { - ServerInterceptor interceptor = - ((ServerInterceptorBuilder) filter).buildServerInterceptor( - filterConfig, selectedOverrideConfigs.get(namedFilterConfig.name)); - if (interceptor != null) { - filterInterceptors.add(interceptor); - } - } else { - logger.log(Level.WARNING, "HttpFilterConfig(type URL: " - + filterConfig.typeUrl() + ") is not supported on server-side. " - + "Probably a bug at ClientXdsClient verification."); - } + // Short circuit. + if (filterConfigs == null) { + perRouteInterceptors.put(route, noopInterceptor); + continue; + } + + // Override vhost filter configs with more specific per-route configs. + Map perRouteOverrides = ImmutableMap.builder() + .putAll(virtualHost.filterConfigOverrides()) + .putAll(route.filterConfigOverrides()) + .buildKeepingLast(); + + // Interceptors for this vhost/route combo. + List interceptors = new ArrayList<>(filterConfigs.size()); + for (NamedFilterConfig namedFilter : filterConfigs) { + String name = namedFilter.name; + FilterConfig config = namedFilter.filterConfig; + FilterConfig overrideConfig = perRouteOverrides.get(name); + String filterKey = namedFilter.filterStateKey(); + + Filter filter = chainFilters.get(filterKey); + checkNotNull(filter, "chainFilters.get(%s)", filterKey); + ServerInterceptor interceptor = filter.buildServerInterceptor(config, overrideConfig); + + if (interceptor != null) { + interceptors.add(interceptor); } } - ServerInterceptor interceptor = combineInterceptors(filterInterceptors); - perRouteInterceptors.put(route, interceptor); + + // Combine interceptors produced by different filters into a single one that executes + // them sequentially. The order is preserved. 
+ perRouteInterceptors.put(route, combineInterceptors(interceptors)); } } + return perRouteInterceptors.buildOrThrow(); } @@ -553,8 +725,9 @@ public Listener interceptCall(ServerCall call, }; } - private void handleConfigNotFound(StatusException exception) { + private void handleConfigNotFoundOrMismatch(StatusException exception) { cleanUpRouteDiscoveryStates(); + shutdownActiveFilters(); List toRelease = getSuppliersInUse(); filterChainSelectorManager.updateSelector(FilterChainSelector.NO_FILTER_CHAIN); for (SslContextProviderSupplier s: toRelease) { @@ -623,72 +796,65 @@ private RouteDiscoveryState(String resourceName) { } @Override - public void onChanged(final RdsUpdate update) { - syncContext.execute(new Runnable() { - @Override - public void run() { - if (!routeDiscoveryStates.containsKey(resourceName)) { - return; - } - if (savedVirtualHosts == null && !isPending) { - logger.log(Level.WARNING, "Received valid Rds {0} configuration.", resourceName); - } - savedVirtualHosts = ImmutableList.copyOf(update.virtualHosts); - updateRdsRoutingConfig(); - maybeUpdateSelector(); + public void onResourceChanged(final StatusOr update) { + syncContext.execute(() -> { + if (!routeDiscoveryStates.containsKey(resourceName)) { + return; // Watcher has been cancelled. 
} - }); - } - @Override - public void onResourceDoesNotExist(final String resourceName) { - syncContext.execute(new Runnable() { - @Override - public void run() { - if (!routeDiscoveryStates.containsKey(resourceName)) { - return; + if (update.hasValue()) { + if (savedVirtualHosts == null && !isPending) { + logger.log(Level.WARNING, "Received valid Rds {0} configuration.", resourceName); } - logger.log(Level.WARNING, "Rds {0} unavailable", resourceName); + savedVirtualHosts = ImmutableList.copyOf(update.getValue().virtualHosts); + } else { + logger.log(Level.WARNING, "Rds {0} unavailable: {1}", + new Object[]{resourceName, update.getStatus()}); savedVirtualHosts = null; - updateRdsRoutingConfig(); - maybeUpdateSelector(); } + // In both cases, a change has occurred that requires a config update. + updateRdsRoutingConfig(); + maybeUpdateSelector(); }); } @Override - public void onError(final Status error) { - syncContext.execute(new Runnable() { - @Override - public void run() { - if (!routeDiscoveryStates.containsKey(resourceName)) { - return; - } - logger.log(Level.WARNING, "Error loading RDS resource {0} from XdsClient: {1}.", - new Object[]{resourceName, error}); - maybeUpdateSelector(); + public void onAmbientError(final Status error) { + syncContext.execute(() -> { + if (!routeDiscoveryStates.containsKey(resourceName)) { + return; // Watcher has been cancelled. } + String description = error.getDescription() == null ? "" : error.getDescription() + " "; + Status errorWithNodeId = error.withDescription( + description + "xDS node ID: " + xdsClient.getBootstrapInfo().node().getId()); + logger.log(Level.WARNING, "Error loading RDS resource {0} from XdsClient: {1}.", + new Object[]{resourceName, errorWithNodeId}); + + // Per gRFC A88, ambient errors should not trigger a configuration change. + // Therefore, we do NOT call maybeUpdateSelector() here. 
}); } private void updateRdsRoutingConfig() { for (FilterChain filterChain : savedRdsRoutingConfigRef.keySet()) { - if (resourceName.equals(filterChain.httpConnectionManager().rdsName())) { - ServerRoutingConfig updatedRoutingConfig; - if (savedVirtualHosts == null) { - updatedRoutingConfig = ServerRoutingConfig.FAILING_ROUTING_CONFIG; - } else { - ImmutableMap updatedInterceptors = - generatePerRouteInterceptors( - filterChain.httpConnectionManager().httpFilterConfigs(), - savedVirtualHosts); - updatedRoutingConfig = ServerRoutingConfig.create(savedVirtualHosts, - updatedInterceptors); - } - logger.log(Level.FINEST, "Updating filter chain {0} rds routing config: {1}", - new Object[]{filterChain.name(), updatedRoutingConfig}); - savedRdsRoutingConfigRef.get(filterChain).set(updatedRoutingConfig); + HttpConnectionManager hcm = filterChain.httpConnectionManager(); + if (!resourceName.equals(hcm.rdsName())) { + continue; } + + ServerRoutingConfig updatedRoutingConfig; + if (savedVirtualHosts == null) { + updatedRoutingConfig = ServerRoutingConfig.FAILING_ROUTING_CONFIG; + } else { + HashMap chainFilters = activeFilters.get(filterChain.name()); + ImmutableMap interceptors = generatePerRouteInterceptors( + hcm.httpFilterConfigs(), savedVirtualHosts, chainFilters); + updatedRoutingConfig = ServerRoutingConfig.create(savedVirtualHosts, interceptors); + } + + logger.log(Level.FINEST, "Updating filter chain {0} rds routing config: {1}", + new Object[]{filterChain.name(), updatedRoutingConfig}); + savedRdsRoutingConfigRef.get(filterChain).set(updatedRoutingConfig); } } diff --git a/xds/src/main/java/io/grpc/xds/client/AllowedGrpcServices.java b/xds/src/main/java/io/grpc/xds/client/AllowedGrpcServices.java new file mode 100644 index 00000000000..e2d77689fca --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/client/AllowedGrpcServices.java @@ -0,0 +1,66 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.client; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableMap; +import io.grpc.CallCredentials; +import io.grpc.Internal; +import java.util.Map; +import java.util.Optional; + +/** + * Wrapper for allowed gRPC services keyed by target URI. + */ +@Internal +@AutoValue +public abstract class AllowedGrpcServices { + public abstract ImmutableMap services(); + + public static AllowedGrpcServices create(Map services) { + return new AutoValue_AllowedGrpcServices(ImmutableMap.copyOf(services)); + } + + public static AllowedGrpcServices empty() { + return create(ImmutableMap.of()); + } + + /** + * Represents an allowed gRPC service configuration with call credentials. 
+ */ + @Internal + @AutoValue + public abstract static class AllowedGrpcService { + public abstract ConfiguredChannelCredentials configuredChannelCredentials(); + + public abstract Optional callCredentials(); + + public static Builder builder() { + return new AutoValue_AllowedGrpcServices_AllowedGrpcService.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder configuredChannelCredentials( + ConfiguredChannelCredentials credentials); + + public abstract Builder callCredentials(CallCredentials callCredentials); + + public abstract AllowedGrpcService build(); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/client/BackendMetricPropagation.java b/xds/src/main/java/io/grpc/xds/client/BackendMetricPropagation.java new file mode 100644 index 00000000000..f0e2c9484b4 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/client/BackendMetricPropagation.java @@ -0,0 +1,133 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.client; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableSet; +import io.grpc.Internal; +import java.util.Objects; +import javax.annotation.Nullable; + +/** + * Represents the configuration for which ORCA metrics should be propagated from backend + * to LRS load reports, as defined in gRFC A85. 
+ */ +@Internal +public final class BackendMetricPropagation { + + public final boolean propagateCpuUtilization; + public final boolean propagateMemUtilization; + public final boolean propagateApplicationUtilization; + + private final boolean propagateAllNamedMetrics; + private final ImmutableSet namedMetricKeys; + + private BackendMetricPropagation( + boolean propagateCpuUtilization, + boolean propagateMemUtilization, + boolean propagateApplicationUtilization, + boolean propagateAllNamedMetrics, + ImmutableSet namedMetricKeys) { + this.propagateCpuUtilization = propagateCpuUtilization; + this.propagateMemUtilization = propagateMemUtilization; + this.propagateApplicationUtilization = propagateApplicationUtilization; + this.propagateAllNamedMetrics = propagateAllNamedMetrics; + this.namedMetricKeys = checkNotNull(namedMetricKeys, "namedMetricKeys"); + } + + /** + * Creates a BackendMetricPropagation from a list of metric specifications. + * + * @param metricSpecs list of metric specification strings from CDS resource + * @return BackendMetricPropagation instance + */ + public static BackendMetricPropagation fromMetricSpecs( + @Nullable java.util.List metricSpecs) { + if (metricSpecs == null || metricSpecs.isEmpty()) { + return new BackendMetricPropagation(false, false, false, false, ImmutableSet.of()); + } + + boolean propagateCpuUtilization = false; + boolean propagateMemUtilization = false; + boolean propagateApplicationUtilization = false; + boolean propagateAllNamedMetrics = false; + ImmutableSet.Builder namedMetricKeysBuilder = ImmutableSet.builder(); + for (String spec : metricSpecs) { + if (spec == null) { + continue; + } + switch (spec) { + case "cpu_utilization": + propagateCpuUtilization = true; + break; + case "mem_utilization": + propagateMemUtilization = true; + break; + case "application_utilization": + propagateApplicationUtilization = true; + break; + case "named_metrics.*": + propagateAllNamedMetrics = true; + break; + default: + if 
(spec.startsWith("named_metrics.")) { + String metricKey = spec.substring("named_metrics.".length()); + if (!metricKey.isEmpty()) { + namedMetricKeysBuilder.add(metricKey); + } + } + } + } + + return new BackendMetricPropagation( + propagateCpuUtilization, + propagateMemUtilization, + propagateApplicationUtilization, + propagateAllNamedMetrics, + namedMetricKeysBuilder.build()); + } + + /** + * Returns whether the given named metric key should be propagated. + */ + public boolean shouldPropagateNamedMetric(String metricKey) { + return propagateAllNamedMetrics || namedMetricKeys.contains(metricKey); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BackendMetricPropagation that = (BackendMetricPropagation) o; + return propagateCpuUtilization == that.propagateCpuUtilization + && propagateMemUtilization == that.propagateMemUtilization + && propagateApplicationUtilization == that.propagateApplicationUtilization + && propagateAllNamedMetrics == that.propagateAllNamedMetrics + && Objects.equals(namedMetricKeys, that.namedMetricKeys); + } + + @Override + public int hashCode() { + return Objects.hash(propagateCpuUtilization, propagateMemUtilization, + propagateApplicationUtilization, propagateAllNamedMetrics, namedMetricKeys); + } +} \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/client/Bootstrapper.java b/xds/src/main/java/io/grpc/xds/client/Bootstrapper.java index fe0c0050b52..b8d6444e3b3 100644 --- a/xds/src/main/java/io/grpc/xds/client/Bootstrapper.java +++ b/xds/src/main/java/io/grpc/xds/client/Bootstrapper.java @@ -26,6 +26,7 @@ import io.grpc.xds.client.EnvoyProtoData.Node; import java.util.List; import java.util.Map; +import java.util.Optional; import javax.annotation.Nullable; /** @@ -61,16 +62,26 @@ public abstract static class ServerInfo { public abstract boolean ignoreResourceDeletion(); + public abstract boolean 
isTrustedXdsServer(); + + public abstract boolean resourceTimerIsTransientError(); + + public abstract boolean failOnDataErrors(); + @VisibleForTesting public static ServerInfo create(String target, @Nullable Object implSpecificConfig) { - return new AutoValue_Bootstrapper_ServerInfo(target, implSpecificConfig, false); + return new AutoValue_Bootstrapper_ServerInfo(target, implSpecificConfig, + false, false, false, false); } @VisibleForTesting public static ServerInfo create( - String target, Object implSpecificConfig, boolean ignoreResourceDeletion) { + String target, Object implSpecificConfig, + boolean ignoreResourceDeletion, boolean isTrustedXdsServer, + boolean resourceTimerIsTransientError, boolean failOnDataErrors) { return new AutoValue_Bootstrapper_ServerInfo(target, implSpecificConfig, - ignoreResourceDeletion); + ignoreResourceDeletion, isTrustedXdsServer, + resourceTimerIsTransientError, failOnDataErrors); } } @@ -195,11 +206,18 @@ public abstract static class BootstrapInfo { */ public abstract ImmutableMap authorities(); + /** + * Parsed configuration for implementation-specific extensions. + * Returns an opaque object containing the parsed configuration. 
+ */ + public abstract Optional implSpecificObject(); + @VisibleForTesting public static Builder builder() { return new AutoValue_Bootstrapper_BootstrapInfo.Builder() .clientDefaultListenerResourceNameTemplate("%s") - .authorities(ImmutableMap.of()); + .authorities(ImmutableMap.of()) + .implSpecificObject(Optional.empty()); } @AutoValue.Builder @@ -221,7 +239,10 @@ public abstract Builder clientDefaultListenerResourceNameTemplate( public abstract Builder authorities(Map authorities); + public abstract Builder implSpecificObject(Optional implSpecificObject); + public abstract BootstrapInfo build(); } } + } diff --git a/xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java b/xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java index 7ef739c8048..3f4ea8eb5c6 100644 --- a/xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java +++ b/xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java @@ -34,6 +34,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; +import javax.annotation.Nullable; /** * A {@link Bootstrapper} implementation that reads xDS configurations from local file system. @@ -41,6 +43,11 @@ @Internal public abstract class BootstrapperImpl extends Bootstrapper { + public static final String GRPC_EXPERIMENTAL_XDS_FALLBACK = + "GRPC_EXPERIMENTAL_XDS_FALLBACK"; + public static final String GRPC_EXPERIMENTAL_XDS_DATA_ERROR_HANDLING = + "GRPC_EXPERIMENTAL_XDS_DATA_ERROR_HANDLING"; + // Client features. @VisibleForTesting public static final String CLIENT_FEATURE_DISABLE_OVERPROVISIONING = @@ -50,6 +57,17 @@ public abstract class BootstrapperImpl extends Bootstrapper { // Server features. 
private static final String SERVER_FEATURE_IGNORE_RESOURCE_DELETION = "ignore_resource_deletion"; + private static final String SERVER_FEATURE_TRUSTED_XDS_SERVER = "trusted_xds_server"; + private static final String + SERVER_FEATURE_RESOURCE_TIMER_IS_TRANSIENT_ERROR = "resource_timer_is_transient_error"; + private static final String SERVER_FEATURE_FAIL_ON_DATA_ERRORS = "fail_on_data_errors"; + + @VisibleForTesting + static boolean enableXdsFallback = GrpcUtil.getFlag(GRPC_EXPERIMENTAL_XDS_FALLBACK, true); + + @VisibleForTesting + public static boolean xdsDataErrorHandlingEnabled + = GrpcUtil.getFlag(GRPC_EXPERIMENTAL_XDS_DATA_ERROR_HANDLING, false); protected final XdsLogger logger; @@ -64,6 +82,7 @@ protected BootstrapperImpl() { protected abstract Object getImplSpecificConfig(Map serverConfig, String serverUri) throws XdsInitializationException; + /** * Reads and parses bootstrap config. The config is expected to be in JSON format. */ @@ -102,6 +121,9 @@ protected BootstrapInfo.Builder bootstrapBuilder(Map rawData) throw new XdsInitializationException("Invalid bootstrap: 'xds_servers' does not exist."); } List servers = parseServerInfos(rawServerConfigs, logger); + if (servers.size() > 1 && !enableXdsFallback) { + servers = ImmutableList.of(servers.get(0)); + } builder.servers(servers); Node.Builder nodeBuilder = Node.newBuilder(); @@ -208,6 +230,9 @@ protected BootstrapInfo.Builder bootstrapBuilder(Map rawData) if (rawAuthorityServers == null || rawAuthorityServers.isEmpty()) { authorityServers = servers; } else { + if (rawAuthorityServers.size() > 1 && !enableXdsFallback) { + rawAuthorityServers = ImmutableList.of(rawAuthorityServers.get(0)); + } authorityServers = parseServerInfos(rawAuthorityServers, logger); } authorityInfoMapBuilder.put( @@ -216,9 +241,18 @@ protected BootstrapInfo.Builder bootstrapBuilder(Map rawData) builder.authorities(authorityInfoMapBuilder.buildOrThrow()); } + Map rawAllowedGrpcServices = JsonUtil.getObject(rawData, 
"allowed_grpc_services"); + builder.implSpecificObject(parseImplSpecificObject(rawAllowedGrpcServices)); + return builder; } + protected Optional parseImplSpecificObject( + @Nullable Map rawAllowedGrpcServices) + throws XdsInitializationException { + return Optional.empty(); + } + private List parseServerInfos(List rawServerConfigs, XdsLogger logger) throws XdsInitializationException { logger.log(XdsLogLevel.INFO, "Configured with {0} xDS servers", rawServerConfigs.size()); @@ -233,14 +267,27 @@ private List parseServerInfos(List rawServerConfigs, XdsLogger lo Object implSpecificConfig = getImplSpecificConfig(serverConfig, serverUri); + boolean resourceTimerIsTransientError = false; boolean ignoreResourceDeletion = false; - List serverFeatures = JsonUtil.getListOfStrings(serverConfig, "server_features"); + boolean failOnDataErrors = false; + // "For forward compatibility reasons, the client will ignore any entry in the list that it + // does not understand, regardless of type." + List serverFeatures = JsonUtil.getList(serverConfig, "server_features"); if (serverFeatures != null) { logger.log(XdsLogLevel.INFO, "Server features: {0}", serverFeatures); - ignoreResourceDeletion = serverFeatures.contains(SERVER_FEATURE_IGNORE_RESOURCE_DELETION); + if (serverFeatures.contains(SERVER_FEATURE_IGNORE_RESOURCE_DELETION)) { + ignoreResourceDeletion = true; + } + resourceTimerIsTransientError = xdsDataErrorHandlingEnabled + && serverFeatures.contains(SERVER_FEATURE_RESOURCE_TIMER_IS_TRANSIENT_ERROR); + failOnDataErrors = xdsDataErrorHandlingEnabled + && serverFeatures.contains(SERVER_FEATURE_FAIL_ON_DATA_ERRORS); } servers.add( - ServerInfo.create(serverUri, implSpecificConfig, ignoreResourceDeletion)); + ServerInfo.create(serverUri, implSpecificConfig, ignoreResourceDeletion, + serverFeatures != null + && serverFeatures.contains(SERVER_FEATURE_TRUSTED_XDS_SERVER), + resourceTimerIsTransientError, failOnDataErrors)); } return servers.build(); } diff --git 
a/xds/src/main/java/io/grpc/xds/client/ConfiguredChannelCredentials.java b/xds/src/main/java/io/grpc/xds/client/ConfiguredChannelCredentials.java new file mode 100644 index 00000000000..c6b9d774b4d --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/client/ConfiguredChannelCredentials.java @@ -0,0 +1,48 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.client; + +import com.google.auto.value.AutoValue; +import io.grpc.ChannelCredentials; +import io.grpc.Internal; + +/** + * Composition of {@link ChannelCredentials} and {@link ChannelCredsConfig}. + */ +@Internal +@AutoValue +public abstract class ConfiguredChannelCredentials { + public abstract ChannelCredentials channelCredentials(); + + public abstract ChannelCredsConfig channelCredsConfig(); + + public static ConfiguredChannelCredentials create(ChannelCredentials creds, + ChannelCredsConfig config) { + return new AutoValue_ConfiguredChannelCredentials(creds, config); + } + + /** + * Configuration for channel credentials. + */ + @Internal + public interface ChannelCredsConfig { + /** + * Returns the type of the credentials. 
+ */ + String type(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/client/ControlPlaneClient.java b/xds/src/main/java/io/grpc/xds/client/ControlPlaneClient.java index 3074d1120ad..59f439d3687 100644 --- a/xds/src/main/java/io/grpc/xds/client/ControlPlaneClient.java +++ b/xds/src/main/java/io/grpc/xds/client/ControlPlaneClient.java @@ -16,7 +16,6 @@ package io.grpc.xds.client; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; @@ -40,7 +39,6 @@ import io.grpc.xds.client.XdsClient.ResourceStore; import io.grpc.xds.client.XdsClient.XdsResponseHandler; import io.grpc.xds.client.XdsLogger.XdsLogLevel; -import io.grpc.xds.client.XdsTransportFactory.EventHandler; import io.grpc.xds.client.XdsTransportFactory.StreamingCall; import io.grpc.xds.client.XdsTransportFactory.XdsTransport; import java.util.Collection; @@ -60,7 +58,6 @@ */ final class ControlPlaneClient { - public static final String CLOSED_BY_SERVER = "Closed by server"; private final SynchronizationContext syncContext; private final InternalLogId logId; private final XdsLogger logger; @@ -72,7 +69,6 @@ final class ControlPlaneClient { private final BackoffPolicy.Provider backoffPolicyProvider; private final Stopwatch stopwatch; private final Node bootstrapNode; - private final XdsClient xdsClient; // Last successfully applied version_info for each resource type. Starts with empty string. 
// A version_info is used to update management server with client's most recent knowledge of @@ -80,13 +76,15 @@ final class ControlPlaneClient { private final Map, String> versions = new HashMap<>(); private boolean shutdown; + private boolean inError; + @Nullable private AdsStream adsStream; @Nullable private BackoffPolicy retryBackoffPolicy; @Nullable private ScheduledHandle rpcRetryTimer; - private MessagePrettyPrinter messagePrinter; + private final MessagePrettyPrinter messagePrinter; /** An entity that manages ADS RPCs over a single channel. */ ControlPlaneClient( @@ -100,7 +98,6 @@ final class ControlPlaneClient { SynchronizationContext syncContext, BackoffPolicy.Provider backoffPolicyProvider, Supplier stopwatchSupplier, - XdsClient xdsClient, MessagePrettyPrinter messagePrinter) { this.serverInfo = checkNotNull(serverInfo, "serverInfo"); this.xdsTransport = checkNotNull(xdsTransport, "xdsTransport"); @@ -110,10 +107,9 @@ final class ControlPlaneClient { this.timeService = checkNotNull(timeService, "timeService"); this.syncContext = checkNotNull(syncContext, "syncContext"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); - this.xdsClient = checkNotNull(xdsClient, "xdsClient"); this.messagePrinter = checkNotNull(messagePrinter, "messagePrinter"); stopwatch = checkNotNull(stopwatchSupplier, "stopwatchSupplier").get(); - logId = InternalLogId.allocate("xds-client", serverInfo.target()); + logId = InternalLogId.allocate("xds-cp-client", serverInfo.target()); logger = XdsLogger.withLogId(logId); logger.log(XdsLogLevel.INFO, "Created"); } @@ -140,26 +136,41 @@ public String toString() { return logId.toString(); } + public ServerInfo getServerInfo() { + return serverInfo; + } + /** * Updates the resource subscription for the given resource type. */ // Must be synchronized. 
void adjustResourceSubscription(XdsResourceType resourceType) { - if (isInBackoff()) { + if (rpcRetryTimer != null && rpcRetryTimer.isPending()) { return; } if (adsStream == null) { startRpcStream(); + // when the stream becomes ready, it will send the discovery requests + return; + } + + // We will do the rest of the method as part of the readyHandler when the stream is ready. + if (!isConnected()) { + return; } + Collection resources = resourceStore.getSubscribedResources(serverInfo, resourceType); if (resources == null) { resources = Collections.emptyList(); } adsStream.sendDiscoveryRequest(resourceType, resources); + resourceStore.startMissingResourceTimers(resources, resourceType); + if (resources.isEmpty()) { - // The resource type no longer has subscribing resources; clean up references to it + // The resource type no longer has subscribing resources; clean up references to it, except + // for nonces. If the resource type becomes used again the control plane can ignore requests + // for old/missing nonces. Old type's nonces are dropped when the ADS stream is restarted. versions.remove(resourceType); - adsStream.respNonces.remove(resourceType); } } @@ -195,35 +206,42 @@ void nackResponse(XdsResourceType type, String nonce, String errorDetail) { adsStream.sendDiscoveryRequest(type, versionInfo, resources, nonce, errorDetail); } - /** - * Returns {@code true} if the resource discovery is currently in backoff. - */ // Must be synchronized. - boolean isInBackoff() { - return rpcRetryTimer != null && rpcRetryTimer.isPending(); + boolean isReady() { + return adsStream != null && adsStream.call != null + && adsStream.call.isReady() && !adsStream.closed; } - // Must be synchronized. 
- boolean isReady() { - return adsStream != null && adsStream.call != null && adsStream.call.isReady(); + boolean isConnected() { + return adsStream != null && adsStream.sentInitialRequest; } /** - * Starts a timer for each requested resource that hasn't been responded to and - * has been waiting for the channel to get ready. + * Used for identifying whether or not when getting a control plane for authority that this + * control plane should be skipped over if there is a fallback. + * + *

Also used by metric to consider this control plane to not be "active". + * + *

A ControlPlaneClient is considered to be in error during the time from when an + * {@link AdsStream} closed without having received a response to the time an AdsStream does + * receive a response. */ - // Must be synchronized. - void readyHandler() { - if (!isReady()) { - return; - } + boolean isInError() { + return inError; + } - if (isInBackoff()) { - rpcRetryTimer.cancel(); - rpcRetryTimer = null; - } - xdsClient.startSubscriberTimersIfNeeded(serverInfo); + /** + * Cleans up outstanding rpcRetryTimer if present, since we are communicating. + * If we haven't sent the initial discovery request for this RPC stream, we will delegate to + * xdsResponseHandler (in practice XdsClientImpl) to do any initialization for a new active + * stream such as starting timers. We then send the initial discovery request. + */ + // Must be synchronized. + void readyHandler(boolean shouldSendInitialRequest) { + if (shouldSendInitialRequest) { + sendDiscoveryRequests(); + } } /** @@ -233,28 +251,51 @@ void readyHandler() { // Must be synchronized. 
private void startRpcStream() { checkState(adsStream == null, "Previous adsStream has not been cleared yet"); + + if (rpcRetryTimer != null) { + rpcRetryTimer.cancel(); + rpcRetryTimer = null; + } + adsStream = new AdsStream(); + adsStream.start(); logger.log(XdsLogLevel.INFO, "ADS stream started"); stopwatch.reset().start(); } + void sendDiscoveryRequests() { + if (rpcRetryTimer != null && rpcRetryTimer.isPending()) { + return; + } + + if (adsStream == null) { + startRpcStream(); + // when the stream becomes ready, it will send the discovery requests + return; + } + + if (isConnected()) { + Set> subscribedResourceTypes = + new HashSet<>(resourceStore.getSubscribedResourceTypesWithTypeUrl().values()); + + for (XdsResourceType type : subscribedResourceTypes) { + adjustResourceSubscription(type); + } + } + } + @VisibleForTesting public final class RpcRetryTask implements Runnable { @Override public void run() { + logger.log(XdsLogLevel.DEBUG, "Retry timeout. Restart ADS stream {0}", logId); if (shutdown) { return; } + startRpcStream(); - Set> subscribedResourceTypes = - new HashSet<>(resourceStore.getSubscribedResourceTypesWithTypeUrl().values()); - for (XdsResourceType type : subscribedResourceTypes) { - Collection resources = resourceStore.getSubscribedResources(serverInfo, type); - if (resources != null) { - adsStream.sendDiscoveryRequest(type, resources); - } - } - xdsResponseHandler.handleStreamRestarted(serverInfo); + + // handling CPC management is triggered in readyHandler } } @@ -264,16 +305,20 @@ XdsResourceType fromTypeUrl(String typeUrl) { return resourceStore.getSubscribedResourceTypesWithTypeUrl().get(typeUrl); } - private class AdsStream implements EventHandler { + private class AdsStream implements XdsTransportFactory.EventHandler { private boolean responseReceived; + private boolean sentInitialRequest; private boolean closed; - // Response nonce for the most recently received discovery responses of each resource type. 
+ // Response nonce for the most recently received discovery responses of each resource type URL. // Client initiated requests start response nonce with empty string. // Nonce in each response is echoed back in the following ACK/NACK request. It is // used for management server to identify which response the client is ACKing/NACking. // To avoid confusion, client-initiated requests will always use the nonce in - // most recently received responses of each resource type. - private final Map, String> respNonces = new HashMap<>(); + // most recently received responses of each resource type. Nonces are never deleted from the + // map; nonces are only discarded once the stream closes because xds_protocol says "the + // management server should not send a DiscoveryResponse for any DiscoveryRequest that has a + // stale nonce." + private final Map respNonces = new HashMap<>(); private final StreamingCall call; private final MethodDescriptor methodDescriptor = AggregatedDiscoveryServiceGrpc.getStreamAggregatedResourcesMethod(); @@ -281,6 +326,9 @@ private class AdsStream implements EventHandler { private AdsStream() { this.call = xdsTransport.createStreamingCall(methodDescriptor.getFullMethodName(), methodDescriptor.getRequestMarshaller(), methodDescriptor.getResponseMarshaller()); + } + + void start() { call.start(this); } @@ -321,12 +369,24 @@ void sendDiscoveryRequest(XdsResourceType type, String versionInfo, final void sendDiscoveryRequest(XdsResourceType type, Collection resources) { logger.log(XdsLogLevel.INFO, "Sending {0} request for resources: {1}", type, resources); sendDiscoveryRequest(type, versions.getOrDefault(type, ""), resources, - respNonces.getOrDefault(type, ""), null); + respNonces.getOrDefault(type.typeUrl(), ""), null); } @Override public void onReady() { - syncContext.execute(ControlPlaneClient.this::readyHandler); + syncContext.execute(() -> { + if (!isReady()) { + logger.log(XdsLogLevel.DEBUG, + "ADS stream ready handler called, but not ready {0}", 
logId); + return; + } + + logger.log(XdsLogLevel.DEBUG, "ADS stream ready {0}", logId); + + boolean hadSentInitialRequest = sentInitialRequest; + sentInitialRequest = true; + readyHandler(!hadSentInitialRequest); + }); } @Override @@ -334,6 +394,14 @@ public void onRecvMessage(DiscoveryResponse response) { syncContext.execute(new Runnable() { @Override public void run() { + if (closed) { + return; + } + boolean isFirstResponse = !responseReceived; + responseReceived = true; + inError = false; + respNonces.put(response.getTypeUrl(), response.getNonce()); + XdsResourceType type = fromTypeUrl(response.getTypeUrl()); if (logger.isLoggable(XdsLogLevel.DEBUG)) { logger.log( @@ -350,7 +418,7 @@ public void run() { return; } handleRpcResponse(type, response.getVersionInfo(), response.getResourcesList(), - response.getNonce()); + response.getNonce(), isFirstResponse); } }); } @@ -358,30 +426,22 @@ public void run() { @Override public void onStatusReceived(final Status status) { syncContext.execute(() -> { - if (status.isOk()) { - handleRpcStreamClosed(Status.UNAVAILABLE.withDescription(CLOSED_BY_SERVER)); - } else { - handleRpcStreamClosed(status); - } + handleRpcStreamClosed(status); }); } final void handleRpcResponse(XdsResourceType type, String versionInfo, List resources, - String nonce) { + String nonce, boolean isFirstResponse) { checkNotNull(type, "type"); - if (closed) { - return; - } - responseReceived = true; - respNonces.put(type, nonce); + ProcessingTracker processingTracker = new ProcessingTracker( () -> call.startRecvMessage(), syncContext); xdsResponseHandler.handleResourceResponse(type, serverInfo, versionInfo, resources, nonce, - processingTracker); + isFirstResponse, processingTracker); processingTracker.onComplete(); } - private void handleRpcStreamClosed(Status error) { + private void handleRpcStreamClosed(Status status) { if (closed) { return; } @@ -390,27 +450,47 @@ private void handleRpcStreamClosed(Status error) { // Reset the backoff sequence if had 
received a response, or backoff sequence // has never been initialized. retryBackoffPolicy = backoffPolicyProvider.get(); + stopwatch.reset(); + } + + Status newStatus = status; + if (responseReceived) { + // A closed ADS stream after a successful response is not considered an error. Servers may + // close streams for various reasons during normal operation, such as load balancing or + // underlying connection hitting its max connection age limit (see gRFC A9). + if (!status.isOk()) { + newStatus = Status.OK; + logger.log(XdsLogLevel.DEBUG, "ADS stream closed with error {0}: {1}. However, a " + + "response was received, so this will not be treated as an error. Cause: {2}", + status.getCode(), status.getDescription(), status.getCause()); + } else { + logger.log(XdsLogLevel.DEBUG, + "ADS stream closed by server after a response was received"); + } + } else { + // If the ADS stream is closed without ever having received a response from the server, then + // the XdsClient should consider that a connectivity error (see gRFC A57). + inError = true; + if (status.isOk()) { + newStatus = Status.UNAVAILABLE.withDescription( + "ADS stream closed with OK before receiving a response"); + } + logger.log( + XdsLogLevel.ERROR, "ADS stream failed with status {0}: {1}. Cause: {2}", + newStatus.getCode(), newStatus.getDescription(), newStatus.getCause()); } + + close(newStatus.asException()); + // FakeClock in tests isn't thread-safe. Schedule the retry timer before notifying callbacks // to avoid TSAN races, since tests may wait until callbacks are called but then would run // concurrently with the stopwatch and schedule. 
long elapsed = stopwatch.elapsed(TimeUnit.NANOSECONDS); long delayNanos = Math.max(0, retryBackoffPolicy.nextBackoffNanos() - elapsed); - rpcRetryTimer = syncContext.schedule( - new RpcRetryTask(), delayNanos, TimeUnit.NANOSECONDS, timeService); - - checkArgument(!error.isOk(), "unexpected OK status"); - String errorMsg = error.getDescription() != null - && error.getDescription().equals(CLOSED_BY_SERVER) - ? "ADS stream closed with status {0}: {1}. Cause: {2}" - : "ADS stream failed with status {0}: {1}. Cause: {2}"; - logger.log( - XdsLogLevel.ERROR, errorMsg, error.getCode(), error.getDescription(), error.getCause()); - closed = true; - xdsResponseHandler.handleStreamClosed(error); - cleanUp(); + rpcRetryTimer = + syncContext.schedule(new RpcRetryTask(), delayNanos, TimeUnit.NANOSECONDS, timeService); - logger.log(XdsLogLevel.INFO, "Retry ADS stream in {0} ns", delayNanos); + xdsResponseHandler.handleStreamClosed(newStatus, !responseReceived); } private void close(Exception error) { @@ -428,4 +508,55 @@ private void cleanUp() { } } } + + @VisibleForTesting + static class FailingXdsTransport implements XdsTransport { + Status error; + + public FailingXdsTransport(Status error) { + this.error = error; + } + + @Override + public StreamingCall + createStreamingCall(String fullMethodName, + MethodDescriptor.Marshaller reqMarshaller, + MethodDescriptor.Marshaller respMarshaller) { + return new FailingXdsStreamingCall<>(); + } + + @Override + public void shutdown() { + // no-op + } + + private class FailingXdsStreamingCall implements StreamingCall { + + @Override + public void start(XdsTransportFactory.EventHandler eventHandler) { + eventHandler.onStatusReceived(error); + } + + @Override + public void sendMessage(ReqT message) { + // no-op + } + + @Override + public void startRecvMessage() { + // no-op + } + + @Override + public void sendError(Exception e) { + // no-op + } + + @Override + public boolean isReady() { + return false; + } + } + } + } diff --git 
a/xds/src/main/java/io/grpc/xds/client/LoadStatsManager2.java b/xds/src/main/java/io/grpc/xds/client/LoadStatsManager2.java index be9d3587d14..cd858dccd99 100644 --- a/xds/src/main/java/io/grpc/xds/client/LoadStatsManager2.java +++ b/xds/src/main/java/io/grpc/xds/client/LoadStatsManager2.java @@ -25,6 +25,7 @@ import com.google.common.collect.Sets; import io.grpc.Internal; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.xds.client.Stats.BackendLoadMetricStats; import io.grpc.xds.client.Stats.ClusterStats; import io.grpc.xds.client.Stats.DroppedRequests; @@ -57,6 +58,8 @@ public final class LoadStatsManager2 { private final Map>>> allLoadStats = new HashMap<>(); private final Supplier stopwatchSupplier; + public static boolean isEnabledOrcaLrsPropagation = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_ORCA_LRS_PROPAGATION", false); @VisibleForTesting public LoadStatsManager2(Supplier stopwatchSupplier) { @@ -98,13 +101,20 @@ private synchronized void releaseClusterDropCounter( /** * Gets or creates the stats object for recording loads for the specified locality (in the - * specified cluster with edsServiceName). The returned object is reference counted and the - * caller should use {@link ClusterLocalityStats#release} to release its hard reference + * specified cluster with edsServiceName) with the specified backend metric propagation + * configuration. The returned object is reference counted and the caller should + * use {@link ClusterLocalityStats#release} to release its hard reference * when it is safe to discard the future stats for the locality. 
*/ @VisibleForTesting public synchronized ClusterLocalityStats getClusterLocalityStats( String cluster, @Nullable String edsServiceName, Locality locality) { + return getClusterLocalityStats(cluster, edsServiceName, locality, null); + } + + public synchronized ClusterLocalityStats getClusterLocalityStats( + String cluster, @Nullable String edsServiceName, Locality locality, + @Nullable BackendMetricPropagation backendMetricPropagation) { if (!allLoadStats.containsKey(cluster)) { allLoadStats.put( cluster, @@ -121,8 +131,8 @@ public synchronized ClusterLocalityStats getClusterLocalityStats( if (!localityStats.containsKey(locality)) { localityStats.put( locality, - ReferenceCounted.wrap(new ClusterLocalityStats( - cluster, edsServiceName, locality, stopwatchSupplier.get()))); + ReferenceCounted.wrap(new ClusterLocalityStats(cluster, edsServiceName, + locality, stopwatchSupplier.get(), backendMetricPropagation))); } ReferenceCounted ref = localityStats.get(locality); ref.retain(); @@ -325,6 +335,8 @@ public final class ClusterLocalityStats { private final String edsServiceName; private final Locality locality; private final Stopwatch stopwatch; + @Nullable + private final BackendMetricPropagation backendMetricPropagation; private final AtomicLong callsInProgress = new AtomicLong(); private final AtomicLong callsSucceeded = new AtomicLong(); private final AtomicLong callsFailed = new AtomicLong(); @@ -333,11 +345,12 @@ public final class ClusterLocalityStats { private ClusterLocalityStats( String clusterName, @Nullable String edsServiceName, Locality locality, - Stopwatch stopwatch) { + Stopwatch stopwatch, BackendMetricPropagation backendMetricPropagation) { this.clusterName = checkNotNull(clusterName, "clusterName"); this.edsServiceName = edsServiceName; this.locality = checkNotNull(locality, "locality"); this.stopwatch = checkNotNull(stopwatch, "stopwatch"); + this.backendMetricPropagation = backendMetricPropagation; stopwatch.reset().start(); } @@ -367,17 +380,51 
@@ public void recordCallFinished(Status status) { * requests counter of 1 and the {@code value} if the key is not present in the map. Otherwise, * increments the finished requests counter and adds the {@code value} to the existing * {@link BackendLoadMetricStats}. + * Metrics are filtered based on the backend metric propagation configuration if configured. */ public synchronized void recordBackendLoadMetricStats(Map namedMetrics) { + if (!isEnabledOrcaLrsPropagation) { + namedMetrics.forEach((name, value) -> updateLoadMetricStats(name, value)); + return; + } + namedMetrics.forEach((name, value) -> { - if (!loadMetricStatsMap.containsKey(name)) { - loadMetricStatsMap.put(name, new BackendLoadMetricStats(1, value)); - } else { - loadMetricStatsMap.get(name).addMetricValueAndIncrementRequestsFinished(value); + if (backendMetricPropagation.shouldPropagateNamedMetric(name)) { + updateLoadMetricStats("named_metrics." + name, value); } }); } + private void updateLoadMetricStats(String metricName, double value) { + if (!loadMetricStatsMap.containsKey(metricName)) { + loadMetricStatsMap.put(metricName, new BackendLoadMetricStats(1, value)); + } else { + loadMetricStatsMap.get(metricName).addMetricValueAndIncrementRequestsFinished(value); + } + } + + /** + * Records top-level ORCA metrics (CPU, memory, application utilization) for per-call load + * reporting. Metrics are filtered based on the backend metric propagation configuration + * if configured. 
+ * + * @param cpuUtilization CPU utilization metric value + * @param memUtilization Memory utilization metric value + * @param applicationUtilization Application utilization metric value + */ + public synchronized void recordTopLevelMetrics(double cpuUtilization, double memUtilization, + double applicationUtilization) { + if (backendMetricPropagation.propagateCpuUtilization && cpuUtilization > 0) { + updateLoadMetricStats("cpu_utilization", cpuUtilization); + } + if (backendMetricPropagation.propagateMemUtilization && memUtilization > 0) { + updateLoadMetricStats("mem_utilization", memUtilization); + } + if (backendMetricPropagation.propagateApplicationUtilization && applicationUtilization > 0) { + updateLoadMetricStats("application_utilization", applicationUtilization); + } + } + /** * Release the hard reference for this stats object (previously obtained via {@link * LoadStatsManager2#getClusterLocalityStats}). The object may still be diff --git a/xds/src/main/java/io/grpc/xds/client/XdsClient.java b/xds/src/main/java/io/grpc/xds/client/XdsClient.java index fc7e1777384..982fb6651a9 100644 --- a/xds/src/main/java/io/grpc/xds/client/XdsClient.java +++ b/xds/src/main/java/io/grpc/xds/client/XdsClient.java @@ -27,6 +27,7 @@ import com.google.protobuf.Any; import io.grpc.ExperimentalApi; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.xds.client.Bootstrapper.ServerInfo; import java.net.URI; import java.net.URISyntaxException; @@ -36,6 +37,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.Executor; +import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.Nullable; @@ -117,34 +119,50 @@ public static String percentEncodePath(String input) { return Joiner.on('/').join(encodedSegs); } + /** + * Returns the authority from the resource name. 
+ */ + public static String getAuthorityFromResourceName(String resourceNames) { + String authority; + if (resourceNames.startsWith(XDSTP_SCHEME)) { + URI uri = URI.create(resourceNames); + authority = uri.getAuthority(); + if (authority == null) { + authority = ""; + } + } else { + authority = null; + } + return authority; + } + public interface ResourceUpdate {} /** * Watcher interface for a single requested xDS resource. + * + *

Note that we expect that the implementer to: + * - Comply with the guarantee to not generate certain statuses by the library: + * https://grpc.github.io/grpc/core/md_doc_statuscodes.html. If the code needs to be + * propagated to the channel, override it with {@link io.grpc.Status.Code#UNAVAILABLE}. + * - Keep {@link Status} description in one form or another, as it contains valuable debugging + * information. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10862") public interface ResourceWatcher { /** - * Called when the resource discovery RPC encounters some transient error. - * - *

Note that we expect that the implementer to: - * - Comply with the guarantee to not generate certain statuses by the library: - * https://grpc.github.io/grpc/core/md_doc_statuscodes.html. If the code needs to be - * propagated to the channel, override it with {@link io.grpc.Status.Code#UNAVAILABLE}. - * - Keep {@link Status} description in one form or another, as it contains valuable debugging - * information. + * Called to deliver a resource update or an error. If an error is passed after a valid + * resource has been delivered, the watcher should stop using the previously delivered + * resource. */ - void onError(Status error); + void onResourceChanged(StatusOr update); /** - * Called when the requested resource is not available. - * - * @param resourceName name of the resource requested in discovery request. - */ - void onResourceDoesNotExist(String resourceName); - - void onChanged(T update); + * Called to deliver a transient error that should not affect the watcher's use of any + * previously received resource. 
+ * */ + void onAmbientError(Status error); } /** @@ -154,44 +172,50 @@ public static final class ResourceMetadata { private final String version; private final ResourceMetadataStatus status; private final long updateTimeNanos; + private final boolean cached; @Nullable private final Any rawResource; @Nullable private final UpdateFailureState errorState; private ResourceMetadata( - ResourceMetadataStatus status, String version, long updateTimeNanos, + ResourceMetadataStatus status, String version, long updateTimeNanos, boolean cached, @Nullable Any rawResource, @Nullable UpdateFailureState errorState) { this.status = checkNotNull(status, "status"); this.version = checkNotNull(version, "version"); this.updateTimeNanos = updateTimeNanos; + this.cached = cached; this.rawResource = rawResource; this.errorState = errorState; } - static ResourceMetadata newResourceMetadataUnknown() { - return new ResourceMetadata(ResourceMetadataStatus.UNKNOWN, "", 0, null, null); + public static ResourceMetadata newResourceMetadataUnknown() { + return new ResourceMetadata(ResourceMetadataStatus.UNKNOWN, "", 0, false,null, null); + } + + public static ResourceMetadata newResourceMetadataRequested() { + return new ResourceMetadata(ResourceMetadataStatus.REQUESTED, "", 0, false, null, null); } - static ResourceMetadata newResourceMetadataRequested() { - return new ResourceMetadata(ResourceMetadataStatus.REQUESTED, "", 0, null, null); + public static ResourceMetadata newResourceMetadataDoesNotExist() { + return new ResourceMetadata(ResourceMetadataStatus.DOES_NOT_EXIST, "", 0, false, null, null); } - static ResourceMetadata newResourceMetadataDoesNotExist() { - return new ResourceMetadata(ResourceMetadataStatus.DOES_NOT_EXIST, "", 0, null, null); + public static ResourceMetadata newResourceMetadataTimeout() { + return new ResourceMetadata(ResourceMetadataStatus.TIMEOUT, "", 0, false, null, null); } public static ResourceMetadata newResourceMetadataAcked( Any rawResource, String version, long 
updateTimeNanos) { checkNotNull(rawResource, "rawResource"); return new ResourceMetadata( - ResourceMetadataStatus.ACKED, version, updateTimeNanos, rawResource, null); + ResourceMetadataStatus.ACKED, version, updateTimeNanos, true, rawResource, null); } - static ResourceMetadata newResourceMetadataNacked( + public static ResourceMetadata newResourceMetadataNacked( ResourceMetadata metadata, String failedVersion, long failedUpdateTime, - String failedDetails) { + String failedDetails, boolean cached) { checkNotNull(metadata, "metadata"); return new ResourceMetadata(ResourceMetadataStatus.NACKED, - metadata.getVersion(), metadata.getUpdateTimeNanos(), metadata.getRawResource(), + metadata.getVersion(), metadata.getUpdateTimeNanos(), cached, metadata.getRawResource(), new UpdateFailureState(failedVersion, failedUpdateTime, failedDetails)); } @@ -210,6 +234,11 @@ public long getUpdateTimeNanos() { return updateTimeNanos; } + /** Returns whether the resource was cached. */ + public boolean isCached() { + return cached; + } + /** The last successfully updated xDS resource as it was returned by the server. */ @Nullable public Any getRawResource() { @@ -231,7 +260,7 @@ public UpdateFailureState getErrorState() { * config_dump.proto */ public enum ResourceMetadataStatus { - UNKNOWN, REQUESTED, DOES_NOT_EXIST, ACKED, NACKED + UNKNOWN, REQUESTED, DOES_NOT_EXIST, ACKED, NACKED, TIMEOUT } /** @@ -298,14 +327,6 @@ public Object getSecurityConfig() { throw new UnsupportedOperationException(); } - /** - * For all subscriber's for the specified server, if the resource hasn't yet been - * resolved then start a timer for it. - */ - protected void startSubscriberTimersIfNeeded(ServerInfo serverInfo) { - throw new UnsupportedOperationException(); - } - /** * Returns a {@link ListenableFuture} to the snapshot of the subscribed resources as * they are at the moment of the call. 
@@ -367,6 +388,23 @@ public LoadStatsManager2.ClusterDropStats addClusterDropStats( public LoadStatsManager2.ClusterLocalityStats addClusterLocalityStats( Bootstrapper.ServerInfo serverInfo, String clusterName, @Nullable String edsServiceName, Locality locality) { + return addClusterLocalityStats(serverInfo, clusterName, edsServiceName, locality, null); + } + + /** + * Adds load stats for the specified locality (in the specified cluster with edsServiceName) by + * using the returned object to record RPCs. Load stats recorded with the returned object will + * be reported to the load reporting server. The returned object is reference counted and the + * caller should use {@link LoadStatsManager2.ClusterLocalityStats#release} to release its + * hard reference when it is safe to stop reporting RPC loads for the specified locality + * in the future. + * + * @param backendMetricPropagation Configuration for which backend metrics should be propagated + * to LRS load reports. If null, all metrics will be propagated (legacy behavior). + */ + public LoadStatsManager2.ClusterLocalityStats addClusterLocalityStats( + Bootstrapper.ServerInfo serverInfo, String clusterName, @Nullable String edsServiceName, + Locality locality, @Nullable BackendMetricPropagation backendMetricPropagation) { throw new UnsupportedOperationException(); } @@ -378,6 +416,23 @@ public Map getServerLrsClientMap() { throw new UnsupportedOperationException(); } + /** Callback used to report a gauge metric value for server connections. */ + public interface ServerConnectionCallback { + void reportServerConnectionGauge(boolean isConnected, String xdsServer); + } + + /** + * Reports whether xDS client has a "working" ADS stream to xDS server. The definition of a + * working stream is defined in gRFC A78. 
+ * + * @see + * A78-grpc-metrics-wrr-pf-xds.md + */ + public Future reportServerConnections(ServerConnectionCallback callback) { + throw new UnsupportedOperationException(); + } + static final class ProcessingTracker { private final AtomicInteger pendingTask = new AtomicInteger(1); private final Executor executor; @@ -403,30 +458,39 @@ interface XdsResponseHandler { /** Called when a xds response is received. */ void handleResourceResponse( XdsResourceType resourceType, ServerInfo serverInfo, String versionInfo, - List resources, String nonce, ProcessingTracker processingTracker); + List resources, String nonce, boolean isFirstResponse, + ProcessingTracker processingTracker); /** Called when the ADS stream is closed passively. */ // Must be synchronized. - void handleStreamClosed(Status error); - - /** Called when the ADS stream has been recreated. */ - // Must be synchronized. - void handleStreamRestarted(ServerInfo serverInfo); + void handleStreamClosed(Status error, boolean shouldTryFallback); } - public interface ResourceStore { + interface ResourceStore { + /** - * Returns the collection of resources currently subscribing to or {@code null} if not - * subscribing to any resources for the given type. + * Returns the collection of resources currently subscribed to which have an authority matching + * one of those for which the ControlPlaneClient associated with the specified ServerInfo is + * the active one, or {@code null} if no such resources are currently subscribed to. * *

Note an empty collection indicates subscribing to resources of the given type with * wildcard mode. + * + * @param serverInfo the xds server to get the resources from + * @param type the type of the resources that should be retrieved */ // Must be synchronized. @Nullable - Collection getSubscribedResources(ServerInfo serverInfo, - XdsResourceType type); + Collection getSubscribedResources( + ServerInfo serverInfo, XdsResourceType type); Map> getSubscribedResourceTypesWithTypeUrl(); + + /** + * For any of the subscribers to one of the specified resources, if there isn't a result or + * an existing timer for the resource, start a timer for the resource. + */ + void startMissingResourceTimers(Collection resourceNames, + XdsResourceType resourceType); } } diff --git a/xds/src/main/java/io/grpc/xds/client/XdsClientImpl.java b/xds/src/main/java/io/grpc/xds/client/XdsClientImpl.java index 79147cd9862..0584a3dbfdd 100644 --- a/xds/src/main/java/io/grpc/xds/client/XdsClientImpl.java +++ b/xds/src/main/java/io/grpc/xds/client/XdsClientImpl.java @@ -18,7 +18,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.client.Bootstrapper.XDSTP_SCHEME; import static io.grpc.xds.client.XdsResourceType.ParsedResource; import static io.grpc.xds.client.XdsResourceType.ValidatedResourceUpdate; @@ -26,14 +25,15 @@ import com.google.common.base.Joiner; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import com.google.protobuf.Any; import io.grpc.Internal; import io.grpc.InternalLogId; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import 
io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.BackoffPolicy; @@ -41,30 +41,34 @@ import io.grpc.xds.client.Bootstrapper.AuthorityInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.XdsClient.ResourceStore; -import io.grpc.xds.client.XdsClient.XdsResponseHandler; import io.grpc.xds.client.XdsLogger.XdsLogLevel; -import java.net.URI; +import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.concurrent.Executor; +import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import javax.annotation.Nullable; /** * XdsClient implementation. */ @Internal -public final class XdsClientImpl extends XdsClient implements XdsResponseHandler, ResourceStore { +public final class XdsClientImpl extends XdsClient implements ResourceStore { // Longest time to wait, since the subscription to some resource, for concluding its absence. @VisibleForTesting public static final int INITIAL_RESOURCE_FETCH_TIMEOUT_SEC = 15; + public static final int EXTENDED_RESOURCE_FETCH_TIMEOUT_SEC = 30; private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @@ -74,21 +78,25 @@ public void uncaughtException(Thread t, Throwable e) { XdsLogLevel.ERROR, "Uncaught exception in XdsClient SynchronizationContext. Panic!", e); - // TODO(chengyuanzhang): better error handling. + // TODO: better error handling. 
throw new AssertionError(e); } }); - private final Map loadStatsManagerMap = - new HashMap<>(); - final Map serverLrsClientMap = - new HashMap<>(); - + private final Map loadStatsManagerMap = new HashMap<>(); + final Map serverLrsClientMap = new HashMap<>(); + /** Map of authority to its activated control plane client (affected by xds fallback). + * The last entry in the list for each value is the "active" CPC for the matching key */ + private final Map> activatedCpClients = new HashMap<>(); private final Map serverCpClientMap = new HashMap<>(); + + /** Maps resource type to the corresponding map of subscribers (keyed by resource name). */ private final Map, Map>> resourceSubscribers = new HashMap<>(); + /** Maps typeUrl to the corresponding XdsResourceType. */ private final Map> subscribedResourceTypeUrls = new HashMap<>(); + private final XdsTransportFactory xdsTransportFactory; private final Bootstrapper.BootstrapInfo bootstrapInfo; private final ScheduledExecutorService timeService; @@ -100,6 +108,7 @@ public void uncaughtException(Thread t, Throwable e) { private final XdsLogger logger; private volatile boolean isShutdown; private final MessagePrettyPrinter messagePrinter; + private final XdsClientMetricReporter metricReporter; public XdsClientImpl( XdsTransportFactory xdsTransportFactory, @@ -109,7 +118,8 @@ public XdsClientImpl( Supplier stopwatchSupplier, TimeProvider timeProvider, MessagePrettyPrinter messagePrinter, - Object securityConfig) { + Object securityConfig, + XdsClientMetricReporter metricReporter) { this.xdsTransportFactory = xdsTransportFactory; this.bootstrapInfo = bootstrapInfo; this.timeService = timeService; @@ -118,53 +128,12 @@ public XdsClientImpl( this.timeProvider = timeProvider; this.messagePrinter = messagePrinter; this.securityConfig = securityConfig; + this.metricReporter = metricReporter; logId = InternalLogId.allocate("xds-client", null); logger = XdsLogger.withLogId(logId); logger.log(XdsLogLevel.INFO, "Created"); } - @Override 
- public void handleResourceResponse( - XdsResourceType xdsResourceType, ServerInfo serverInfo, String versionInfo, - List resources, String nonce, ProcessingTracker processingTracker) { - checkNotNull(xdsResourceType, "xdsResourceType"); - syncContext.throwIfNotInThisSynchronizationContext(); - Set toParseResourceNames = - xdsResourceType.shouldRetrieveResourceKeysForArgs() - ? getResourceKeys(xdsResourceType) - : null; - XdsResourceType.Args args = new XdsResourceType.Args(serverInfo, versionInfo, nonce, - bootstrapInfo, securityConfig, toParseResourceNames); - handleResourceUpdate(args, resources, xdsResourceType, processingTracker); - } - - @Override - public void handleStreamClosed(Status error) { - syncContext.throwIfNotInThisSynchronizationContext(); - cleanUpResourceTimers(); - for (Map> subscriberMap : - resourceSubscribers.values()) { - for (ResourceSubscriber subscriber : subscriberMap.values()) { - if (!subscriber.hasResult()) { - subscriber.onError(error, null); - } - } - } - } - - @Override - public void handleStreamRestarted(ServerInfo serverInfo) { - syncContext.throwIfNotInThisSynchronizationContext(); - for (Map> subscriberMap : - resourceSubscribers.values()) { - for (ResourceSubscriber subscriber : subscriberMap.values()) { - if (subscriber.serverInfo.equals(serverInfo)) { - subscriber.restartTimer(); - } - } - } - } - @Override public void shutdown() { syncContext.execute( @@ -181,7 +150,8 @@ public void run() { for (final LoadReportClient lrsClient : serverLrsClientMap.values()) { lrsClient.stopLoadReporting(); } - cleanUpResourceTimers(); + cleanUpResourceTimers(null); + activatedCpClients.clear(); } }); } @@ -196,20 +166,53 @@ public Map> getSubscribedResourceTypesWithTypeUrl() { return Collections.unmodifiableMap(subscribedResourceTypeUrls); } + private ControlPlaneClient getActiveCpc(String authority) { + List controlPlaneClients = activatedCpClients.get(authority); + if (controlPlaneClients == null || controlPlaneClients.isEmpty()) { + 
return null; + } + + return controlPlaneClients.get(controlPlaneClients.size() - 1); + } + @Nullable @Override - public Collection getSubscribedResources(ServerInfo serverInfo, - XdsResourceType type) { + public Collection getSubscribedResources( + ServerInfo serverInfo, XdsResourceType type) { + ControlPlaneClient targetCpc = serverCpClientMap.get(serverInfo); + if (targetCpc == null) { + return null; + } + + // This should include all of the authorities that targetCpc or a fallback from it is serving + List authorities = activatedCpClients.entrySet().stream() + .filter(entry -> entry.getValue().contains(targetCpc)) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + Map> resources = resourceSubscribers.getOrDefault(type, Collections.emptyMap()); - ImmutableSet.Builder builder = ImmutableSet.builder(); - for (String key : resources.keySet()) { - if (resources.get(key).serverInfo.equals(serverInfo)) { - builder.add(key); + + Collection retVal = resources.entrySet().stream() + .filter(entry -> authorities.contains(entry.getValue().authority)) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + + return retVal.isEmpty() ? null : retVal; + } + + @Override + public void startMissingResourceTimers(Collection resourceNames, + XdsResourceType resourceType) { + Map> subscriberMap = + resourceSubscribers.get(resourceType); + + for (String resourceName : resourceNames) { + ResourceSubscriber subscriber = subscriberMap.get(resourceName); + if (subscriber.respTimer == null && !subscriber.hasResult()) { + subscriber.restartTimer(); } } - Collection retVal = builder.build(); - return retVal.isEmpty() ? null : retVal; } // As XdsClient APIs becomes resource agnostic, subscribed resource types are dynamic. 
@@ -225,7 +228,7 @@ public void run() { // A map from a "resource type" to a map ("resource name": "resource metadata") ImmutableMap.Builder, Map> metadataSnapshot = ImmutableMap.builder(); - for (XdsResourceType resourceType: resourceSubscribers.keySet()) { + for (XdsResourceType resourceType : resourceSubscribers.keySet()) { ImmutableMap.Builder metadataMap = ImmutableMap.builder(); for (Map.Entry> resourceEntry : resourceSubscribers.get(resourceType).entrySet()) { @@ -246,9 +249,9 @@ public Object getSecurityConfig() { @Override public void watchXdsResource(XdsResourceType type, - String resourceName, - ResourceWatcher watcher, - Executor watcherExecutor) { + String resourceName, + ResourceWatcher watcher, + Executor watcherExecutor) { syncContext.execute(new Runnable() { @Override @SuppressWarnings("unchecked") @@ -259,36 +262,125 @@ public void run() { } ResourceSubscriber subscriber = (ResourceSubscriber) resourceSubscribers.get(type).get(resourceName); + if (subscriber == null) { logger.log(XdsLogLevel.INFO, "Subscribe {0} resource {1}", type, resourceName); subscriber = new ResourceSubscriber<>(type, resourceName); resourceSubscribers.get(type).put(resourceName, subscriber); - if (subscriber.controlPlaneClient != null) { - subscriber.controlPlaneClient.adjustResourceSubscription(type); + + if (subscriber.errorDescription == null) { + CpcWithFallbackState cpcToUse = manageControlPlaneClient(subscriber); + if (cpcToUse.cpc != null) { + cpcToUse.cpc.adjustResourceSubscription(type); + } } } + subscriber.addWatcher(watcher, watcherExecutor); } }); } + /** + * Gets a ControlPlaneClient for the subscriber's authority, creating one if necessary. + * If there already was an active CPC for this authority, and it is different from the one + * identified, then do fallback to the identified one (cpcToUse). 
+ * + * @return identified CPC or {@code null} (if there are no valid ServerInfos associated with the + * subscriber's authority or CPC's for all are in backoff), and whether did a fallback. + */ + @VisibleForTesting + private CpcWithFallbackState manageControlPlaneClient( + ResourceSubscriber subscriber) { + + ControlPlaneClient cpcToUse; + boolean didFallback = false; + try { + cpcToUse = getOrCreateControlPlaneClient(subscriber.authority); + } catch (IllegalArgumentException e) { + if (subscriber.errorDescription == null) { + subscriber.errorDescription = "Bad configuration: " + e.getMessage(); + } + + subscriber.onError( + Status.INVALID_ARGUMENT.withDescription(subscriber.errorDescription), null); + return new CpcWithFallbackState(null, false); + } catch (IOException e) { + logger.log(XdsLogLevel.DEBUG, + "Could not create a control plane client for authority {0}: {1}", + subscriber.authority, e.getMessage()); + return new CpcWithFallbackState(null, false); + } + + ControlPlaneClient activeCpClient = getActiveCpc(subscriber.authority); + if (cpcToUse != activeCpClient) { + addCpcToAuthority(subscriber.authority, cpcToUse); // makes it active + if (activeCpClient != null) { + didFallback = cpcToUse != null && !cpcToUse.isInError(); + if (didFallback) { + logger.log(XdsLogLevel.INFO, "Falling back to XDS server {0}", + cpcToUse.getServerInfo().target()); + } else { + logger.log(XdsLogLevel.WARNING, "No working fallback XDS Servers found from {0}", + activeCpClient.getServerInfo().target()); + } + } + } + + return new CpcWithFallbackState(cpcToUse, didFallback); + } + + private void addCpcToAuthority(String authority, ControlPlaneClient cpcToUse) { + List controlPlaneClients = + activatedCpClients.computeIfAbsent(authority, k -> new ArrayList<>()); + + if (controlPlaneClients.contains(cpcToUse)) { + return; + } + + // if there are any missing CPCs between the last one and cpcToUse, add them + add cpcToUse + ImmutableList serverInfos = getServerInfos(authority); + 
for (int i = controlPlaneClients.size(); i < serverInfos.size(); i++) { + ServerInfo serverInfo = serverInfos.get(i); + ControlPlaneClient cpc = serverCpClientMap.get(serverInfo); + controlPlaneClients.add(cpc); + logger.log(XdsLogLevel.DEBUG, "Adding control plane client {0} to authority {1}", + cpc, authority); + cpcToUse.sendDiscoveryRequests(); + if (cpc == cpcToUse) { + break; + } + } + } + @Override public void cancelXdsResourceWatch(XdsResourceType type, - String resourceName, - ResourceWatcher watcher) { + String resourceName, + ResourceWatcher watcher) { syncContext.execute(new Runnable() { @Override @SuppressWarnings("unchecked") public void run() { ResourceSubscriber subscriber = (ResourceSubscriber) resourceSubscribers.get(type).get(resourceName); + if (subscriber == null) { + logger.log(XdsLogLevel.WARNING, "double cancel of resource watch for {0}:{1}", + type.typeName(), resourceName); + return; + } subscriber.removeWatcher(watcher); if (!subscriber.isWatched()) { subscriber.cancelResourceWatch(); resourceSubscribers.get(type).remove(resourceName); - if (subscriber.controlPlaneClient != null) { - subscriber.controlPlaneClient.adjustResourceSubscription(type); + + List controlPlaneClients = + activatedCpClients.get(subscriber.authority); + if (controlPlaneClients != null) { + controlPlaneClients.forEach((cpc) -> { + cpc.adjustResourceSubscription(type); + }); } + if (resourceSubscribers.get(type).isEmpty()) { resourceSubscribers.remove(type); subscribedResourceTypeUrls.remove(type.typeUrl()); @@ -318,9 +410,22 @@ public void run() { public LoadStatsManager2.ClusterLocalityStats addClusterLocalityStats( final ServerInfo serverInfo, String clusterName, @Nullable String edsServiceName, Locality locality) { + return addClusterLocalityStats(serverInfo, clusterName, edsServiceName, locality, null); + } + + @Override + public LoadStatsManager2.ClusterLocalityStats addClusterLocalityStats( + final ServerInfo serverInfo, + String clusterName, + @Nullable String 
edsServiceName, + Locality locality, + @Nullable BackendMetricPropagation backendMetricPropagation) { LoadStatsManager2 loadStatsManager = loadStatsManagerMap.get(serverInfo); + LoadStatsManager2.ClusterLocalityStats loadCounter = - loadStatsManager.getClusterLocalityStats(clusterName, edsServiceName, locality); + loadStatsManager.getClusterLocalityStats( + clusterName, edsServiceName, locality, backendMetricPropagation); + syncContext.execute(new Runnable() { @Override public void run() { @@ -341,30 +446,6 @@ public String toString() { return logId.toString(); } - @Override - protected void startSubscriberTimersIfNeeded(ServerInfo serverInfo) { - if (isShutDown()) { - return; - } - - syncContext.execute(new Runnable() { - @Override - public void run() { - if (isShutDown()) { - return; - } - - for (Map> subscriberMap : resourceSubscribers.values()) { - for (ResourceSubscriber subscriber : subscriberMap.values()) { - if (subscriber.serverInfo.equals(serverInfo) && subscriber.respTimer == null) { - subscriber.restartTimer(); - } - } - } - } - }); - } - private Set getResourceKeys(XdsResourceType xdsResourceType) { if (!resourceSubscribers.containsKey(xdsResourceType)) { return null; @@ -373,33 +454,77 @@ private Set getResourceKeys(XdsResourceType xdsResourceType) { return resourceSubscribers.get(xdsResourceType).keySet(); } - private void cleanUpResourceTimers() { + // cpcForThisStream is null when doing shutdown + private void cleanUpResourceTimers(ControlPlaneClient cpcForThisStream) { + Collection authoritiesForCpc = getActiveAuthorities(cpcForThisStream); + String target = cpcForThisStream == null ? 
"null" : cpcForThisStream.getServerInfo().target(); + logger.log(XdsLogLevel.DEBUG, "Cleaning up resource timers for CPC {0}, authorities {1}", + target, authoritiesForCpc); + for (Map> subscriberMap : resourceSubscribers.values()) { for (ResourceSubscriber subscriber : subscriberMap.values()) { - subscriber.stopTimer(); + if (cpcForThisStream == null || authoritiesForCpc.contains(subscriber.authority)) { + subscriber.stopTimer(); + } + } + } + } + + private ControlPlaneClient getOrCreateControlPlaneClient(String authority) throws IOException { + // Optimize for the common case of a working ads stream already exists for the authority + ControlPlaneClient activeCpc = getActiveCpc(authority); + if (activeCpc != null && !activeCpc.isInError()) { + return activeCpc; + } + + ImmutableList serverInfos = getServerInfos(authority); + if (serverInfos == null) { + throw new IllegalArgumentException("No xds servers found for authority " + authority); + } + + for (ServerInfo serverInfo : serverInfos) { + ControlPlaneClient cpc = getOrCreateControlPlaneClient(serverInfo); + if (cpc.isInError()) { + continue; } + return cpc; } + + // Everything existed and is in backoff so throw + throw new IOException("All xds transports for authority " + authority + " are in backoff"); } - public ControlPlaneClient getOrCreateControlPlaneClient(ServerInfo serverInfo) { + private ControlPlaneClient getOrCreateControlPlaneClient(ServerInfo serverInfo) { syncContext.throwIfNotInThisSynchronizationContext(); if (serverCpClientMap.containsKey(serverInfo)) { return serverCpClientMap.get(serverInfo); } - XdsTransportFactory.XdsTransport xdsTransport = xdsTransportFactory.create(serverInfo); + logger.log(XdsLogLevel.DEBUG, "Creating control plane client for {0}", serverInfo.target()); + XdsTransportFactory.XdsTransport xdsTransport; + try { + xdsTransport = xdsTransportFactory.create(serverInfo); + } catch (Exception e) { + String msg = String.format("Failed to create xds transport for %s: %s", + 
serverInfo.target(), e.getMessage()); + logger.log(XdsLogLevel.WARNING, msg); + xdsTransport = + new ControlPlaneClient.FailingXdsTransport(Status.UNAVAILABLE.withDescription(msg)); + } + ControlPlaneClient controlPlaneClient = new ControlPlaneClient( xdsTransport, serverInfo, bootstrapInfo.node(), - this, + new ResponseHandler(serverInfo), this, timeService, syncContext, backoffPolicyProvider, stopwatchSupplier, - this, - messagePrinter); + messagePrinter + ); + serverCpClientMap.put(serverInfo, controlPlaneClient); LoadStatsManager2 loadStatsManager = new LoadStatsManager2(stopwatchSupplier); @@ -419,45 +544,49 @@ public Map getServerLrsClientMap() { } @Nullable - private ServerInfo getServerInfo(String resource) { - if (resource.startsWith(XDSTP_SCHEME)) { - URI uri = URI.create(resource); - String authority = uri.getAuthority(); - if (authority == null) { - authority = ""; - } + private ImmutableList getServerInfos(String authority) { + if (authority != null) { AuthorityInfo authorityInfo = bootstrapInfo.authorities().get(authority); if (authorityInfo == null || authorityInfo.xdsServers().isEmpty()) { return null; } - return authorityInfo.xdsServers().get(0); + return authorityInfo.xdsServers(); } else { - return bootstrapInfo.servers().get(0); // use first server + return bootstrapInfo.servers(); } } @SuppressWarnings("unchecked") private void handleResourceUpdate( XdsResourceType.Args args, List resources, XdsResourceType xdsResourceType, - ProcessingTracker processingTracker) { + boolean isFirstResponse, ProcessingTracker processingTracker) { + ControlPlaneClient controlPlaneClient = serverCpClientMap.get(args.serverInfo); + + if (isFirstResponse) { + shutdownLowerPriorityCpcs(controlPlaneClient); + } + ValidatedResourceUpdate result = xdsResourceType.parse(args, resources); logger.log(XdsLogger.XdsLogLevel.INFO, "Received {0} Response version {1} nonce {2}. 
Parsed resources: {3}", - xdsResourceType.typeName(), args.versionInfo, args.nonce, result.unpackedResources); + xdsResourceType.typeName(), args.versionInfo, args.nonce, result.unpackedResources); Map> parsedResources = result.parsedResources; Set invalidResources = result.invalidResources; + metricReporter.reportResourceUpdates(Long.valueOf(parsedResources.size()), + Long.valueOf(invalidResources.size()), + args.getServerInfo().target(), xdsResourceType.typeUrl()); + List errors = result.errors; String errorDetail = null; if (errors.isEmpty()) { checkArgument(invalidResources.isEmpty(), "found invalid resources but missing errors"); - serverCpClientMap.get(args.serverInfo).ackResponse(xdsResourceType, args.versionInfo, - args.nonce); + controlPlaneClient.ackResponse(xdsResourceType, args.versionInfo, args.nonce); } else { errorDetail = Joiner.on('\n').join(errors); logger.log(XdsLogLevel.WARNING, "Failed processing {0} Response version {1} nonce {2}. Errors:\n{3}", xdsResourceType.typeName(), args.versionInfo, args.nonce, errorDetail); - serverCpClientMap.get(args.serverInfo).nackResponse(xdsResourceType, args.nonce, errorDetail); + controlPlaneClient.nackResponse(xdsResourceType, args.nonce, errorDetail); } long updateTime = timeProvider.currentTimeNanos(); @@ -474,8 +603,21 @@ private void handleResourceUpdate( } if (invalidResources.contains(resourceName)) { - // The resource update is invalid. Capture the error without notifying the watchers. + // The resource update is invalid (NACK). Handle as a data error. subscriber.onRejected(args.versionInfo, updateTime, errorDetail); + + // Handle data errors (NACKs) based on fail_on_data_errors server feature. + // When xdsDataErrorHandlingEnabled is true and fail_on_data_errors is present, + // delete cached data so onError will call onResourceChanged instead of onAmbientError. + // When xdsDataErrorHandlingEnabled is false, use old behavior (always keep cached data). 
+ if (BootstrapperImpl.xdsDataErrorHandlingEnabled && subscriber.data != null + && args.serverInfo.failOnDataErrors()) { + subscriber.data = null; + } + // Call onError, which will decide whether to call onResourceChanged or onAmbientError + // based on whether data exists after the above deletion. + subscriber.onError(Status.UNAVAILABLE.withDescription(errorDetail), processingTracker); + continue; } // Nothing else to do for incremental ADS resources. @@ -483,73 +625,106 @@ private void handleResourceUpdate( continue; } - // Handle State of the World ADS: invalid resources. - if (invalidResources.contains(resourceName)) { - // The resource is missing. Reuse the cached resource if possible. - if (subscriber.data == null) { - // No cached data. Notify the watchers of an invalid update. - subscriber.onError(Status.UNAVAILABLE.withDescription(errorDetail), processingTracker); - } - continue; - } - // For State of the World services, notify watchers when their watched resource is missing // from the ADS update. Note that we can only do this if the resource update is coming from // the same xDS server that the ResourceSubscriber is subscribed to. - if (subscriber.serverInfo.equals(args.serverInfo)) { - subscriber.onAbsent(processingTracker); + if (getActiveCpc(subscriber.authority) == controlPlaneClient) { + subscriber.onAbsent(processingTracker, args.serverInfo); } } } - /** - * Tracks a single subscribed resource. 
- */ + @Override + public Future reportServerConnections(ServerConnectionCallback callback) { + SettableFuture future = SettableFuture.create(); + syncContext.execute(() -> { + serverCpClientMap.forEach((serverInfo, controlPlaneClient) -> + callback.reportServerConnectionGauge( + !controlPlaneClient.isInError(), serverInfo.target())); + future.set(null); + }); + return future; + } + + private void shutdownLowerPriorityCpcs(ControlPlaneClient activatedCpc) { + // For each authority, remove any control plane clients, with lower priority than the activated + // one, from activatedCpClients storing them all in cpcsToShutdown. + Set cpcsToShutdown = new HashSet<>(); + for ( List cpcsForAuth : activatedCpClients.values()) { + if (cpcsForAuth == null) { + continue; + } + int index = cpcsForAuth.indexOf(activatedCpc); + if (index > -1) { + cpcsToShutdown.addAll(cpcsForAuth.subList(index + 1, cpcsForAuth.size())); + cpcsForAuth.subList(index + 1, cpcsForAuth.size()).clear(); // remove lower priority cpcs + } + } + + // Shutdown any lower priority control plane clients identified above that aren't still being + // used by another authority. If they are still being used let the XDS server know that we + // no longer are interested in subscriptions for authorities we are no longer responsible for. + for (ControlPlaneClient cpc : cpcsToShutdown) { + if (activatedCpClients.values().stream().noneMatch(list -> list.contains(cpc))) { + cpc.shutdown(); + serverCpClientMap.remove(cpc.getServerInfo()); + } else { + cpc.sendDiscoveryRequests(); + } + } + } + + + /** Tracks a single subscribed resource. 
*/ private final class ResourceSubscriber { - @Nullable private final ServerInfo serverInfo; - @Nullable private final ControlPlaneClient controlPlaneClient; + @Nullable + private final String authority; private final XdsResourceType type; private final String resource; private final Map, Executor> watchers = new HashMap<>(); - @Nullable private T data; + @Nullable + private T data; private boolean absent; // Tracks whether the deletion has been ignored per bootstrap server feature. // See https://github.com/grpc/proposal/blob/master/A53-xds-ignore-resource-deletion.md private boolean resourceDeletionIgnored; - @Nullable private ScheduledHandle respTimer; - @Nullable private ResourceMetadata metadata; - @Nullable private String errorDescription; + @Nullable + private ScheduledHandle respTimer; + @Nullable + private ResourceMetadata metadata; + @Nullable + private String errorDescription; + @Nullable + private Status lastError; ResourceSubscriber(XdsResourceType type, String resource) { syncContext.throwIfNotInThisSynchronizationContext(); this.type = type; this.resource = resource; - this.serverInfo = getServerInfo(resource); - if (serverInfo == null) { + this.authority = getAuthorityFromResourceName(resource); + if (getServerInfos(authority) == null) { this.errorDescription = "Wrong configuration: xds server does not exist for resource " + resource; - this.controlPlaneClient = null; return; } + // Initialize metadata in UNKNOWN state to cover the case when resource subscriber, // is created but not yet requested because the client is in backoff. 
this.metadata = ResourceMetadata.newResourceMetadataUnknown(); + } - ControlPlaneClient controlPlaneClient = null; - try { - controlPlaneClient = getOrCreateControlPlaneClient(serverInfo); - if (controlPlaneClient.isInBackoff()) { - return; - } - } catch (IllegalArgumentException e) { - controlPlaneClient = null; - this.errorDescription = "Bad configuration: " + e.getMessage(); - return; - } finally { - this.controlPlaneClient = controlPlaneClient; - } - - restartTimer(); + @Override + public String toString() { + return "ResourceSubscriber{" + + "resource='" + resource + '\'' + + ", authority='" + authority + '\'' + + ", type=" + type + + ", watchers=" + watchers.size() + + ", data=" + data + + ", absent=" + absent + + ", resourceDeletionIgnored=" + resourceDeletionIgnored + + ", errorDescription='" + errorDescription + '\'' + + '}'; } void addWatcher(ResourceWatcher watcher, Executor watcherExecutor) { @@ -557,20 +732,28 @@ void addWatcher(ResourceWatcher watcher, Executor watcherExecutor) { watchers.put(watcher, watcherExecutor); T savedData = data; boolean savedAbsent = absent; + Status savedError = lastError; watcherExecutor.execute(() -> { if (errorDescription != null) { - watcher.onError(Status.INVALID_ARGUMENT.withDescription(errorDescription)); + watcher.onResourceChanged(StatusOr.fromStatus( + Status.INVALID_ARGUMENT.withDescription(errorDescription))); return; } if (savedData != null) { - notifyWatcher(watcher, savedData); + watcher.onResourceChanged(StatusOr.fromValue(savedData)); + if (savedError != null) { + watcher.onAmbientError(savedError); + } + } else if (savedError != null) { + watcher.onResourceChanged(StatusOr.fromStatus(savedError)); } else if (savedAbsent) { - watcher.onResourceDoesNotExist(resource); + watcher.onResourceChanged(StatusOr.fromStatus( + Status.NOT_FOUND.withDescription("Resource " + resource + " does not exist"))); } }); } - void removeWatcher(ResourceWatcher watcher) { + void removeWatcher(ResourceWatcher watcher) { 
checkArgument(watchers.containsKey(watcher), "watcher %s not registered", watcher); watchers.remove(watcher); } @@ -579,17 +762,22 @@ void restartTimer() { if (data != null || absent) { // resource already resolved return; } - if (!controlPlaneClient.isReady()) { // When client becomes ready, it triggers a restartTimer + ControlPlaneClient activeCpc = getActiveCpc(authority); + if (activeCpc == null || !activeCpc.isReady()) { + // When client becomes ready, it triggers a restartTimer for all relevant subscribers. return; } + ServerInfo serverInfo = activeCpc.getServerInfo(); + int timeoutSec = serverInfo.resourceTimerIsTransientError() + ? EXTENDED_RESOURCE_FETCH_TIMEOUT_SEC : INITIAL_RESOURCE_FETCH_TIMEOUT_SEC; class ResourceNotFound implements Runnable { @Override public void run() { logger.log(XdsLogLevel.INFO, "{0} resource {1} initial fetch timeout", type, resource); + onAbsent(null, activeCpc.getServerInfo()); respTimer = null; - onAbsent(null); } @Override @@ -601,9 +789,11 @@ public String toString() { // Initial fetch scheduled or rescheduled, transition metadata state to REQUESTED. metadata = ResourceMetadata.newResourceMetadataRequested(); + if (respTimer != null) { + respTimer.cancel(); + } respTimer = syncContext.schedule( - new ResourceNotFound(), INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS, - timeService); + new ResourceNotFound(), timeoutSec, TimeUnit.SECONDS, timeService); } void stopTimer() { @@ -624,8 +814,7 @@ void cancelResourceWatch() { message += " for which we previously ignored a deletion"; logLevel = XdsLogLevel.FORCE_INFO; } - logger.log(logLevel, message, type, resource, - serverInfo != null ? 
serverInfo.target() : "unknown"); + logger.log(logLevel, message, type, resource, getTarget()); } boolean isWatched() { @@ -642,23 +831,25 @@ void onData(ParsedResource parsedResource, String version, long updateTime, respTimer.cancel(); respTimer = null; } - this.metadata = ResourceMetadata - .newResourceMetadataAcked(parsedResource.getRawResource(), version, updateTime); ResourceUpdate oldData = this.data; this.data = parsedResource.getResourceUpdate(); + this.metadata = ResourceMetadata.newResourceMetadataAcked( + parsedResource.getRawResource(), version, updateTime); absent = false; + lastError = null; if (resourceDeletionIgnored) { logger.log(XdsLogLevel.FORCE_INFO, "xds server {0}: server returned new version " + "of resource for which we previously ignored a deletion: type {1} name {2}", - serverInfo != null ? serverInfo.target() : "unknown", type, resource); + getTarget(), type, resource); resourceDeletionIgnored = false; } if (!Objects.equals(oldData, data)) { + StatusOr update = StatusOr.fromValue(data); for (ResourceWatcher watcher : watchers.keySet()) { processingTracker.startTask(); watchers.get(watcher).execute(() -> { try { - notifyWatcher(watcher, data); + watcher.onResourceChanged(update); } finally { processingTracker.onComplete(); } @@ -667,37 +858,85 @@ void onData(ParsedResource parsedResource, String version, long updateTime, } } - void onAbsent(@Nullable ProcessingTracker processingTracker) { + private String getTarget() { + ControlPlaneClient activeCpc = getActiveCpc(authority); + return (activeCpc != null) + ? activeCpc.getServerInfo().target() + : "unknown"; + } + + void onAbsent(@Nullable ProcessingTracker processingTracker, ServerInfo serverInfo) { if (respTimer != null && respTimer.isPending()) { // too early to conclude absence return; } - // Ignore deletion of State of the World resources when this feature is on, - // and the resource is reusable. 
- boolean ignoreResourceDeletionEnabled = - serverInfo != null && serverInfo.ignoreResourceDeletion(); - if (ignoreResourceDeletionEnabled && type.isFullStateOfTheWorld() && data != null) { - if (!resourceDeletionIgnored) { - logger.log(XdsLogLevel.FORCE_WARNING, - "xds server {0}: ignoring deletion for resource type {1} name {2}}", - serverInfo.target(), type, resource); - resourceDeletionIgnored = true; + // Handle data errors (resource deletions) based on fail_on_data_errors server feature. + // When xdsDataErrorHandlingEnabled is true and fail_on_data_errors is not present, + // we treat deletions as ambient errors and keep using the cached resource. + // When fail_on_data_errors is present, we delete the cached resource and fail. + // When xdsDataErrorHandlingEnabled is false, use the old behavior (ignore_resource_deletion). + boolean ignoreResourceDeletionEnabled = serverInfo.ignoreResourceDeletion(); + boolean failOnDataErrors = serverInfo.failOnDataErrors(); + boolean xdsDataErrorHandlingEnabled = BootstrapperImpl.xdsDataErrorHandlingEnabled; + + if (type.isFullStateOfTheWorld() && data != null) { + // New behavior (per gRFC A88): Default is to treat deletions as ambient errors + if (xdsDataErrorHandlingEnabled && !failOnDataErrors) { + if (!resourceDeletionIgnored) { + logger.log(XdsLogLevel.FORCE_WARNING, + "xds server {0}: ignoring deletion for resource type {1} name {2}}", + serverInfo.target(), type, resource); + resourceDeletionIgnored = true; + } + Status deletionStatus = Status.NOT_FOUND.withDescription( + "Resource " + resource + " deleted from server"); + onAmbientError(deletionStatus, processingTracker); + return; + } + // Old behavior: Use ignore_resource_deletion server feature + if (!xdsDataErrorHandlingEnabled && ignoreResourceDeletionEnabled) { + if (!resourceDeletionIgnored) { + logger.log(XdsLogLevel.FORCE_WARNING, + "xds server {0}: ignoring deletion for resource type {1} name {2}}", + serverInfo.target(), type, resource); + 
resourceDeletionIgnored = true; + } + Status deletionStatus = Status.NOT_FOUND.withDescription( + "Resource " + resource + " deleted from server"); + onAmbientError(deletionStatus, processingTracker); + return; } - return; } logger.log(XdsLogLevel.INFO, "Conclude {0} resource {1} not exist", type, resource); if (!absent) { data = null; absent = true; - metadata = ResourceMetadata.newResourceMetadataDoesNotExist(); - for (ResourceWatcher watcher : watchers.keySet()) { + lastError = null; + + Status status; + if (respTimer == null) { + status = Status.NOT_FOUND.withDescription("Resource " + resource + " does not exist"); + metadata = ResourceMetadata.newResourceMetadataDoesNotExist(); + } else { + status = serverInfo.resourceTimerIsTransientError() + ? Status.UNAVAILABLE.withDescription( + "Timed out waiting for resource " + resource + " from xDS server") + : Status.NOT_FOUND.withDescription( + "Timed out waiting for resource " + resource + " from xDS server"); + metadata = serverInfo.resourceTimerIsTransientError() + ? ResourceMetadata.newResourceMetadataTimeout() + : ResourceMetadata.newResourceMetadataDoesNotExist(); + } + + StatusOr update = StatusOr.fromStatus(status); + for (Map.Entry, Executor> entry : watchers.entrySet()) { if (processingTracker != null) { processingTracker.startTask(); } - watchers.get(watcher).execute(() -> { + entry.getValue().execute(() -> { try { - watcher.onResourceDoesNotExist(resource); + entry.getKey().onResourceChanged(update); } finally { if (processingTracker != null) { processingTracker.onComplete(); @@ -720,14 +959,39 @@ void onError(Status error, @Nullable ProcessingTracker tracker) { Status errorAugmented = Status.fromCode(error.getCode()) .withDescription(description + "nodeID: " + bootstrapInfo.node().getId()) .withCause(error.getCause()); + this.lastError = errorAugmented; + + if (data != null) { + // We have cached data, so this is an ambient error. 
+ onAmbientError(errorAugmented, tracker); + } else { + // No data, this is a definitive resource error. + StatusOr update = StatusOr.fromStatus(errorAugmented); + for (Map.Entry, Executor> entry : watchers.entrySet()) { + if (tracker != null) { + tracker.startTask(); + } + entry.getValue().execute(() -> { + try { + entry.getKey().onResourceChanged(update); + } finally { + if (tracker != null) { + tracker.onComplete(); + } + } + }); + } + } + } - for (ResourceWatcher watcher : watchers.keySet()) { + private void onAmbientError(Status error, @Nullable ProcessingTracker tracker) { + for (Map.Entry, Executor> entry : watchers.entrySet()) { if (tracker != null) { tracker.startTask(); } - watchers.get(watcher).execute(() -> { + entry.getValue().execute(() -> { try { - watcher.onError(errorAugmented); + entry.getKey().onAmbientError(error); } finally { if (tracker != null) { tracker.onComplete(); @@ -739,12 +1003,100 @@ void onError(Status error, @Nullable ProcessingTracker tracker) { void onRejected(String rejectedVersion, long rejectedTime, String rejectedDetails) { metadata = ResourceMetadata - .newResourceMetadataNacked(metadata, rejectedVersion, rejectedTime, rejectedDetails); + .newResourceMetadataNacked(metadata, rejectedVersion, rejectedTime, rejectedDetails, + data != null); + } + } + + private class ResponseHandler implements XdsResponseHandler { + final ServerInfo serverInfo; + + ResponseHandler(ServerInfo serverInfo) { + this.serverInfo = serverInfo; } - private void notifyWatcher(ResourceWatcher watcher, T update) { - watcher.onChanged(update); + @Override + public void handleResourceResponse( + XdsResourceType xdsResourceType, ServerInfo serverInfo, String versionInfo, + List resources, String nonce, boolean isFirstResponse, + ProcessingTracker processingTracker) { + checkNotNull(xdsResourceType, "xdsResourceType"); + syncContext.throwIfNotInThisSynchronizationContext(); + Set toParseResourceNames = + xdsResourceType.shouldRetrieveResourceKeysForArgs() + ? 
getResourceKeys(xdsResourceType) + : null; + XdsResourceType.Args args = new XdsResourceType.Args(serverInfo, versionInfo, nonce, + bootstrapInfo, securityConfig, toParseResourceNames); + handleResourceUpdate(args, resources, xdsResourceType, isFirstResponse, processingTracker); } + + @Override + public void handleStreamClosed(Status status, boolean shouldTryFallback) { + syncContext.throwIfNotInThisSynchronizationContext(); + + ControlPlaneClient cpcClosed = serverCpClientMap.get(serverInfo); + if (cpcClosed == null) { + logger.log(XdsLogLevel.DEBUG, + "Couldn't find closing CPC for {0}, so skipping cleanup and reporting", serverInfo); + return; + } + + cleanUpResourceTimers(cpcClosed); + + if (status.isOk()) { + return; // Not considered an error + } + + metricReporter.reportServerFailure(1L, serverInfo.target()); + + Collection authoritiesForClosedCpc = getActiveAuthorities(cpcClosed); + for (Map> subscriberMap : + resourceSubscribers.values()) { + for (ResourceSubscriber subscriber : subscriberMap.values()) { + if (!authoritiesForClosedCpc.contains(subscriber.authority)) { + continue; + } + // If subscriber already has data, this is an ambient error. 
+ if (subscriber.hasResult()) { + subscriber.onError(status, null); + continue; + } + + // try to fallback to lower priority control plane client + if (shouldTryFallback && manageControlPlaneClient(subscriber).didFallback) { + authoritiesForClosedCpc.remove(subscriber.authority); + if (authoritiesForClosedCpc.isEmpty()) { + return; // optimization: no need to continue once all authorities have done fallback + } + continue; // since we did fallback, don't consider it an error + } + + subscriber.onError(status, null); + } + } + } + } + + private static class CpcWithFallbackState { + ControlPlaneClient cpc; + boolean didFallback; + + private CpcWithFallbackState(ControlPlaneClient cpc, boolean didFallback) { + this.cpc = cpc; + this.didFallback = didFallback; + } + } + + private Collection getActiveAuthorities(ControlPlaneClient cpc) { + List asList = activatedCpClients.entrySet().stream() + .filter(entry -> !entry.getValue().isEmpty() + && cpc == entry.getValue().get(entry.getValue().size() - 1)) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + + // Since this is usually used for contains, use a set when the list is large + return (asList.size() < 100) ? asList : new HashSet<>(asList); } } diff --git a/xds/src/main/java/io/grpc/xds/client/XdsClientMetricReporter.java b/xds/src/main/java/io/grpc/xds/client/XdsClientMetricReporter.java new file mode 100644 index 00000000000..a044d501759 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/client/XdsClientMetricReporter.java @@ -0,0 +1,48 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.client; + +import io.grpc.Internal; + +/** + * Interface for reporting metrics from the xDS client. + */ +@Internal +public interface XdsClientMetricReporter { + + /** + * Reports number of valid and invalid resources. + * + * @param validResourceCount Number of resources that were valid. + * @param invalidResourceCount Number of resources that were invalid. + * @param xdsServer Target URI of the xDS server with which the XdsClient is communicating. + * @param resourceType Type of XDS resource (e.g., "envoy.config.listener.v3.Listener"). + */ + default void reportResourceUpdates(long validResourceCount, long invalidResourceCount, + String xdsServer, String resourceType) { + } + + /** + * Reports number of xDS servers going from healthy to unhealthy. + * + * @param serverFailure Number of xDS server failures. + * @param xdsServer Target URI of the xDS server with which the XdsClient is communicating. 
+ */ + default void reportServerFailure(long serverFailure, String xdsServer) { + } + +} diff --git a/xds/src/main/java/io/grpc/xds/client/XdsResourceType.java b/xds/src/main/java/io/grpc/xds/client/XdsResourceType.java index 8c3d31604e4..4d6e75b1809 100644 --- a/xds/src/main/java/io/grpc/xds/client/XdsResourceType.java +++ b/xds/src/main/java/io/grpc/xds/client/XdsResourceType.java @@ -20,7 +20,6 @@ import static io.grpc.xds.client.XdsClient.canonifyResourceName; import static io.grpc.xds.client.XdsClient.isResourceNameValid; -import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; @@ -34,18 +33,18 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; import javax.annotation.Nullable; @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10847") public abstract class XdsResourceType { + private static final Logger log = Logger.getLogger(XdsResourceType.class.getName()); + static final String TYPE_URL_RESOURCE = "type.googleapis.com/envoy.service.discovery.v3.Resource"; protected static final String TRANSPORT_SOCKET_NAME_TLS = "envoy.transport_sockets.tls"; - @VisibleForTesting - public static final String HASH_POLICY_FILTER_STATE_KEY = "io.grpc.channel_id"; - protected static final String TYPE_URL_CLUSTER_CONFIG = - "type.googleapis.com/envoy.extensions.clusters.aggregate.v3.ClusterConfig"; protected static final String TYPE_URL_TYPED_STRUCT_UDPA = "type.googleapis.com/udpa.type.v1.TypedStruct"; protected static final String TYPE_URL_TYPED_STRUCT = @@ -181,6 +180,16 @@ ValidatedResourceUpdate parse(Args args, List resources) { typeName(), unpackedClassName().getSimpleName(), cname, e.getMessage())); invalidResources.add(cname); continue; + } catch (Throwable t) { + log.log(Level.FINE, "Unexpected error in doParse()", t); + String errorMessage = 
t.getClass().getSimpleName(); + if (t.getMessage() != null) { + errorMessage = errorMessage + ": " + t.getMessage(); + } + errors.add(String.format("%s response '%s' unexpected error: %s", + typeName(), cname, errorMessage)); + invalidResources.add(cname); + continue; } // Resource parsed successfully. @@ -249,53 +258,4 @@ public ValidatedResourceUpdate(Map> parsedResources, this.errors = errors; } } - - @VisibleForTesting - public static final class StructOrError { - - /** - * Returns a {@link StructOrError} for the successfully converted data object. - */ - public static StructOrError fromStruct(T struct) { - return new StructOrError<>(struct); - } - - /** - * Returns a {@link StructOrError} for the failure to convert the data object. - */ - public static StructOrError fromError(String errorDetail) { - return new StructOrError<>(errorDetail); - } - - private final String errorDetail; - private final T struct; - - private StructOrError(T struct) { - this.struct = checkNotNull(struct, "struct"); - this.errorDetail = null; - } - - private StructOrError(String errorDetail) { - this.struct = null; - this.errorDetail = checkNotNull(errorDetail, "errorDetail"); - } - - /** - * Returns struct if exists, otherwise null. - */ - @VisibleForTesting - @Nullable - public T getStruct() { - return struct; - } - - /** - * Returns error detail if exists, otherwise null. 
- */ - @VisibleForTesting - @Nullable - public String getErrorDetail() { - return errorDetail; - } - } } diff --git a/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java b/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java index 39b80bbcc03..91b77b05d01 100644 --- a/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java +++ b/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java @@ -26,9 +26,12 @@ public static Matchers.HeaderMatcher parseHeaderMatcher( io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto) { switch (proto.getHeaderMatchSpecifierCase()) { case EXACT_MATCH: + @SuppressWarnings("deprecation") // gRFC A63: support indefinitely + String exactMatch = proto.getExactMatch(); return Matchers.HeaderMatcher.forExactValue( - proto.getName(), proto.getExactMatch(), proto.getInvertMatch()); + proto.getName(), exactMatch, proto.getInvertMatch()); case SAFE_REGEX_MATCH: + @SuppressWarnings("deprecation") // gRFC A63: support indefinitely String rawPattern = proto.getSafeRegexMatch().getRegex(); Pattern safeRegExMatch; try { @@ -49,14 +52,20 @@ public static Matchers.HeaderMatcher parseHeaderMatcher( return Matchers.HeaderMatcher.forPresent( proto.getName(), proto.getPresentMatch(), proto.getInvertMatch()); case PREFIX_MATCH: + @SuppressWarnings("deprecation") // gRFC A63: support indefinitely + String prefixMatch = proto.getPrefixMatch(); return Matchers.HeaderMatcher.forPrefix( - proto.getName(), proto.getPrefixMatch(), proto.getInvertMatch()); + proto.getName(), prefixMatch, proto.getInvertMatch()); case SUFFIX_MATCH: + @SuppressWarnings("deprecation") // gRFC A63: support indefinitely + String suffixMatch = proto.getSuffixMatch(); return Matchers.HeaderMatcher.forSuffix( - proto.getName(), proto.getSuffixMatch(), proto.getInvertMatch()); + proto.getName(), suffixMatch, proto.getInvertMatch()); case CONTAINS_MATCH: + @SuppressWarnings("deprecation") // gRFC A63: support indefinitely + String containsMatch = proto.getContainsMatch(); return 
Matchers.HeaderMatcher.forContains( - proto.getName(), proto.getContainsMatch(), proto.getInvertMatch()); + proto.getName(), containsMatch, proto.getInvertMatch()); case STRING_MATCH: return Matchers.HeaderMatcher.forString( proto.getName(), parseStringMatcher(proto.getStringMatch()), proto.getInvertMatch()); @@ -88,4 +97,25 @@ public static Matchers.StringMatcher parseStringMatcher( "Unknown StringMatcher match pattern: " + proto.getMatchPatternCase()); } } + + /** Translates envoy proto FractionalPercent to internal FractionMatcher. */ + public static Matchers.FractionMatcher parseFractionMatcher( + io.envoyproxy.envoy.type.v3.FractionalPercent proto) { + int denominator; + switch (proto.getDenominator()) { + case HUNDRED: + denominator = 100; + break; + case TEN_THOUSAND: + denominator = 10_000; + break; + case MILLION: + denominator = 1_000_000; + break; + case UNRECOGNIZED: + default: + throw new IllegalArgumentException("Unknown denominator type: " + proto.getDenominator()); + } + return Matchers.FractionMatcher.create(proto.getNumerator(), denominator); + } } diff --git a/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java b/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java new file mode 100644 index 00000000000..7da9a3ab6d9 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java @@ -0,0 +1,67 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds.internal; + +import io.grpc.services.MetricReport; +import java.util.Map; +import java.util.OptionalDouble; + +/** + * Utilities for parsing and resolving metrics from {@link MetricReport}. + */ +public final class MetricReportUtils { + + private MetricReportUtils() {} + + /** + * Resolves a metric value from the report based on the given metric name. + * The logic checks for specific prefixes to determine where to look up the metric: + *

+ * <ul>
+ *   <li>"cpu_utilization" -> getCpuUtilization()</li>
+ *   <li>"application_utilization" -> getApplicationUtilization()</li>
+ *   <li>"mem_utilization" -> getMemoryUtilization()</li>
+ *   <li>"utilization." -> lookup in utilizationMetrics</li>
+ *   <li>"named_metrics." -> lookup in namedMetrics</li>
+ * </ul>
+ * + * @param report The metric report to query. + * @param metricName The name of the custom metric to look up. + * @return The value of the metric if found, or empty if not found. + */ + public static OptionalDouble getMetric(MetricReport report, String metricName) { + if (metricName.equals("cpu_utilization")) { + return OptionalDouble.of(report.getCpuUtilization()); + } else if (metricName.equals("application_utilization")) { + return OptionalDouble.of(report.getApplicationUtilization()); + } else if (metricName.equals("mem_utilization")) { + return OptionalDouble.of(report.getMemoryUtilization()); + } else if (metricName.startsWith("utilization.")) { + Map map = report.getUtilizationMetrics(); + Double val = map.get(metricName.substring("utilization.".length())); + if (val != null) { + return OptionalDouble.of(val); + } + } else if (metricName.startsWith("named_metrics.")) { + Map map = report.getNamedMetrics(); + Double val = map.get(metricName.substring("named_metrics.".length())); + if (val != null) { + return OptionalDouble.of(val); + } + } + return OptionalDouble.empty(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/ProtobufJsonConverter.java b/xds/src/main/java/io/grpc/xds/internal/ProtobufJsonConverter.java new file mode 100644 index 00000000000..964c28c57e0 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/ProtobufJsonConverter.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.grpc.Internal; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Converter for Protobuf {@link Struct} to JSON-like {@link Map}. + */ +@Internal +public final class ProtobufJsonConverter { + private ProtobufJsonConverter() {} + + public static Map convertToJson(Struct struct) { + Map result = new HashMap<>(); + for (Map.Entry entry : struct.getFieldsMap().entrySet()) { + result.put(entry.getKey(), convertValue(entry.getValue())); + } + return result; + } + + private static Object convertValue(Value value) { + switch (value.getKindCase()) { + case STRUCT_VALUE: + return convertToJson(value.getStructValue()); + case LIST_VALUE: + return value.getListValue().getValuesList().stream() + .map(ProtobufJsonConverter::convertValue) + .collect(Collectors.toList()); + case NUMBER_VALUE: + return value.getNumberValue(); + case STRING_VALUE: + return value.getStringValue(); + case BOOL_VALUE: + return value.getBoolValue(); + case NULL_VALUE: + return null; + default: + throw new IllegalArgumentException("Unknown Value type: " + value.getKindCase()); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/XdsInternalAttributes.java b/xds/src/main/java/io/grpc/xds/internal/XdsInternalAttributes.java new file mode 100644 index 00000000000..b05230ea30b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/XdsInternalAttributes.java @@ -0,0 +1,27 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; + +public final class XdsInternalAttributes { + /** Name associated with individual address, if available (e.g., DNS name). */ + @EquivalentAddressGroup.Attr + public static final Attributes.Key ATTR_ADDRESS_NAME = + Attributes.Key.create("io.grpc.xds.XdsAttributes.addressName"); +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java new file mode 100644 index 00000000000..5aeb44c6e2a --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java @@ -0,0 +1,145 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds.internal.extauthz; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import io.grpc.Status; +import io.grpc.xds.internal.Matchers; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesConfig; +import java.util.Optional; + +/** + * Represents the configuration for the external authorization (ext_authz) filter. This class + * encapsulates the settings defined in the + * {@link io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz} proto, providing a + * structured, immutable representation for use within gRPC. It includes configurations for the gRPC + * service used for authorization, header mutation rules, and other filter behaviors. + */ +@AutoValue +public abstract class ExtAuthzConfig { + + /** Creates a new builder for creating {@link ExtAuthzConfig} instances. */ + public static Builder builder() { + return new AutoValue_ExtAuthzConfig.Builder().allowedHeaders(ImmutableList.of()) + .disallowedHeaders(ImmutableList.of()).statusOnError(Status.PERMISSION_DENIED) + .filterEnabled(Matchers.FractionMatcher.create(100, 100)); + } + + /** + * The gRPC service configuration for the external authorization service. This is a required + * field. + * + * @see ExtAuthz#getGrpcService() + */ + public abstract GrpcServiceConfig grpcService(); + + /** + * Changes the filter's behavior on errors from the authorization service. If {@code true}, the + * filter will accept the request even if the authorization service fails or returns an error. + * + * @see ExtAuthz#getFailureModeAllow() + */ + public abstract boolean failureModeAllow(); + + /** + * Determines if the {@code x-envoy-auth-failure-mode-allowed} header is added to the request when + * {@link #failureModeAllow()} is true. 
+ * + * @see ExtAuthz#getFailureModeAllowHeaderAdd() + */ + public abstract boolean failureModeAllowHeaderAdd(); + + /** + * Specifies if the peer certificate is sent to the external authorization service. + * + * @see ExtAuthz#getIncludePeerCertificate() + */ + public abstract boolean includePeerCertificate(); + + /** + * The gRPC status returned to the client when the authorization server returns an error or is + * unreachable. Defaults to {@code PERMISSION_DENIED}. + * + * @see io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz#getStatusOnError() + */ + public abstract Status statusOnError(); + + /** + * Specifies whether to deny requests when the filter is disabled. Defaults to {@code false}. + * + * @see ExtAuthz#getDenyAtDisable() + */ + public abstract boolean denyAtDisable(); + + /** + * The fraction of requests that will be checked by the authorization service. Defaults to all + * requests. + * + * @see ExtAuthz#getFilterEnabled() + */ + public abstract Matchers.FractionMatcher filterEnabled(); + + /** + * Specifies which request headers are sent to the authorization service. If empty, all headers + * are sent. + * + * @see ExtAuthz#getAllowedHeaders() + */ + public abstract ImmutableList allowedHeaders(); + + /** + * Specifies which request headers are not sent to the authorization service. This overrides + * {@link #allowedHeaders()}. + * + * @see ExtAuthz#getDisallowedHeaders() + */ + public abstract ImmutableList disallowedHeaders(); + + /** + * Rules for what modifications an ext_authz server may make to request headers. 
+ * + * @see ExtAuthz#getDecoderHeaderMutationRules() + */ + public abstract Optional decoderHeaderMutationRules(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder grpcService(GrpcServiceConfig grpcService); + + public abstract Builder failureModeAllow(boolean failureModeAllow); + + public abstract Builder failureModeAllowHeaderAdd(boolean failureModeAllowHeaderAdd); + + public abstract Builder includePeerCertificate(boolean includePeerCertificate); + + public abstract Builder statusOnError(Status statusOnError); + + public abstract Builder denyAtDisable(boolean denyAtDisable); + + public abstract Builder filterEnabled(Matchers.FractionMatcher filterEnabled); + + public abstract Builder allowedHeaders(Iterable allowedHeaders); + + public abstract Builder disallowedHeaders(Iterable disallowedHeaders); + + public abstract Builder decoderHeaderMutationRules(HeaderMutationRulesConfig rules); + + public abstract ExtAuthzConfig build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java new file mode 100644 index 00000000000..78edea5c305 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds.internal.extauthz; + +/** + * A custom exception for signaling errors during the parsing of external authorization + * (ext_authz) configurations. + */ +public class ExtAuthzParseException extends Exception { + + private static final long serialVersionUID = 0L; + + public ExtAuthzParseException(String message) { + super(message); + } + + public ExtAuthzParseException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java new file mode 100644 index 00000000000..cefc235e9eb --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java @@ -0,0 +1,87 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import io.grpc.CallCredentials; +import io.grpc.xds.client.ConfiguredChannelCredentials; +import java.time.Duration; +import java.util.Optional; + + +/** + * This class encapsulates the configuration for a gRPC service, including target URI, credentials, + * and other settings. This class is immutable and uses the AutoValue library for its + * implementation. 
+ */ +@AutoValue +public abstract class GrpcServiceConfig { + + public static Builder builder() { + return new AutoValue_GrpcServiceConfig.Builder(); + } + + public abstract GoogleGrpcConfig googleGrpc(); + + public abstract Optional timeout(); + + public abstract ImmutableList initialMetadata(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder googleGrpc(GoogleGrpcConfig googleGrpc); + + public abstract Builder timeout(Duration timeout); + + public abstract Builder initialMetadata(ImmutableList initialMetadata); + + public abstract GrpcServiceConfig build(); + } + + /** + * This class encapsulates settings specific to Google's gRPC implementation, such as target URI + * and credentials. + */ + @AutoValue + public abstract static class GoogleGrpcConfig { + + public static Builder builder() { + return new AutoValue_GrpcServiceConfig_GoogleGrpcConfig.Builder(); + } + + public abstract String target(); + + public abstract ConfiguredChannelCredentials configuredChannelCredentials(); + + public abstract Optional callCredentials(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder target(String target); + + public abstract Builder configuredChannelCredentials( + ConfiguredChannelCredentials channelCredentials); + + public abstract Builder callCredentials(CallCredentials callCredentials); + + public abstract GoogleGrpcConfig build(); + } + } + + +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java new file mode 100644 index 00000000000..319ad3d07e3 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java @@ -0,0 +1,33 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +/** + * Exception thrown when there is an error parsing the gRPC service config. + */ +public class GrpcServiceParseException extends Exception { + + private static final long serialVersionUID = 1L; + + public GrpcServiceParseException(String message) { + super(message); + } + + public GrpcServiceParseException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/HeaderValue.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/HeaderValue.java new file mode 100644 index 00000000000..1b7bb283744 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/HeaderValue.java @@ -0,0 +1,44 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds.internal.grpcservice; + +import com.google.auto.value.AutoValue; +import com.google.protobuf.ByteString; +import java.util.Optional; + +/** + * Represents a header to be mutated or added as part of xDS configuration. + * Avoids direct dependency on Envoy's proto objects while providing an immutable representation. + */ +@AutoValue +public abstract class HeaderValue { + + public static HeaderValue create(String key, String value) { + return new AutoValue_HeaderValue(key, Optional.of(value), Optional.empty()); + } + + public static HeaderValue create(String key, ByteString rawValue) { + return new AutoValue_HeaderValue(key, Optional.empty(), Optional.of(rawValue)); + } + + + public abstract String key(); + + public abstract Optional value(); + + public abstract Optional rawValue(); +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/HeaderValueValidationUtils.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/HeaderValueValidationUtils.java new file mode 100644 index 00000000000..ff0df11bdc5 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/HeaderValueValidationUtils.java @@ -0,0 +1,67 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import java.util.Locale; + +/** + * Utility class for validating HTTP headers. 
+ */ +public final class HeaderValueValidationUtils { + public static final int MAX_HEADER_LENGTH = 16384; + + private HeaderValueValidationUtils() {} + + /** + * Returns true if the header key is disallowed for mutations or validation. + * + * @param key The header key (e.g., "content-type") + */ + public static boolean isDisallowed(String key) { + if (key.isEmpty() || key.length() > MAX_HEADER_LENGTH) { + return true; + } + if (!key.equals(key.toLowerCase(Locale.ROOT))) { + return true; + } + if (key.startsWith("grpc-")) { + return true; + } + if (key.startsWith(":") || key.equals("host")) { + return true; + } + return false; + } + + /** + * Returns true if the header value is disallowed. + * + * @param header The HeaderValue containing key and values + */ + public static boolean isDisallowed(HeaderValue header) { + if (isDisallowed(header.key())) { + return true; + } + if (header.value().isPresent() && header.value().get().length() > MAX_HEADER_LENGTH) { + return true; + } + if (header.rawValue().isPresent() && header.rawValue().get().size() > MAX_HEADER_LENGTH) { + return true; + } + return false; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java new file mode 100644 index 00000000000..b16ec7948ed --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java @@ -0,0 +1,77 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.auto.value.AutoValue; +import com.google.re2j.Pattern; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; +import java.util.Optional; + +/** + * Represents the configuration for header mutation rules, as defined in the + * {@link io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules} proto. + */ +@AutoValue +public abstract class HeaderMutationRulesConfig { + /** Creates a new builder for creating {@link HeaderMutationRulesConfig} instances. */ + public static Builder builder() { + return new AutoValue_HeaderMutationRulesConfig.Builder().disallowAll(false) + .disallowIsError(false); + } + + /** + * If set, allows any header that matches this regular expression. + * + * @see HeaderMutationRules#getAllowExpression() + */ + public abstract Optional allowExpression(); + + /** + * If set, disallows any header that matches this regular expression. + * + * @see HeaderMutationRules#getDisallowExpression() + */ + public abstract Optional disallowExpression(); + + /** + * If true, disallows all header mutations. + * + * @see HeaderMutationRules#getDisallowAll() + */ + public abstract boolean disallowAll(); + + /** + * If true, a disallowed header mutation will result in an error instead of being ignored. 
+ * + * @see HeaderMutationRules#getDisallowIsError() + */ + public abstract boolean disallowIsError(); + + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder allowExpression(Pattern matcher); + + public abstract Builder disallowExpression(Pattern matcher); + + public abstract Builder disallowAll(boolean disallowAll); + + public abstract Builder disallowIsError(boolean disallowIsError); + + public abstract HeaderMutationRulesConfig build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParseException.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParseException.java new file mode 100644 index 00000000000..3782e84a54b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParseException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +/** + * Exception thrown when parsing header mutation rules fails. 
+ */ +public final class HeaderMutationRulesParseException extends Exception { + private static final long serialVersionUID = 1L; + + public HeaderMutationRulesParseException(String message) { + super(message); + } + + public HeaderMutationRulesParseException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParser.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParser.java new file mode 100644 index 00000000000..f6bb2ec508d --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParser.java @@ -0,0 +1,55 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.re2j.Pattern; +import com.google.re2j.PatternSyntaxException; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; + +/** + * Parser for {@link io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules}. 
+ */ +public final class HeaderMutationRulesParser { + + private HeaderMutationRulesParser() {} + + public static HeaderMutationRulesConfig parse(HeaderMutationRules proto) + throws HeaderMutationRulesParseException { + HeaderMutationRulesConfig.Builder builder = HeaderMutationRulesConfig.builder(); + builder.disallowAll(proto.getDisallowAll().getValue()); + builder.disallowIsError(proto.getDisallowIsError().getValue()); + if (proto.hasAllowExpression()) { + builder.allowExpression( + parseRegex(proto.getAllowExpression().getRegex(), "allow_expression")); + } + if (proto.hasDisallowExpression()) { + builder.disallowExpression( + parseRegex(proto.getDisallowExpression().getRegex(), "disallow_expression")); + } + return builder.build(); + } + + private static Pattern parseRegex(String regex, String fieldName) + throws HeaderMutationRulesParseException { + try { + return Pattern.compile(regex); + } catch (PatternSyntaxException e) { + throw new HeaderMutationRulesParseException( + "Invalid regex pattern for " + fieldName + ": " + e.getMessage(), e); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java index 90202b4820a..37d289c1c47 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java @@ -16,8 +16,6 @@ package io.grpc.xds.internal.security; -import static com.google.common.base.Preconditions.checkNotNull; - import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; @@ -44,17 +42,9 @@ final class ClientSslContextProviderFactory /** Creates an SslContextProvider from the given UpstreamTlsContext. 
*/ @Override public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) { - checkNotNull(upstreamTlsContext, "upstreamTlsContext"); - checkNotNull( - upstreamTlsContext.getCommonTlsContext(), - "upstreamTlsContext should have CommonTlsContext"); - if (CommonTlsContextUtil.hasCertProviderInstance( - upstreamTlsContext.getCommonTlsContext())) { - return certProviderClientSslContextProviderFactory.getProvider( - upstreamTlsContext, - bootstrapInfo.node().toEnvoyProtoNode(), - bootstrapInfo.certProviders()); - } - throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!"); + return certProviderClientSslContextProviderFactory.getProvider( + upstreamTlsContext, + bootstrapInfo.node().toEnvoyProtoNode(), + bootstrapInfo.certProviders()); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/CommonTlsContextUtil.java b/xds/src/main/java/io/grpc/xds/internal/security/CommonTlsContextUtil.java index d3003b4a792..bd8a423e683 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/CommonTlsContextUtil.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/CommonTlsContextUtil.java @@ -18,33 +18,21 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; /** Class for utility functions for {@link CommonTlsContext}. 
*/ public final class CommonTlsContextUtil { private CommonTlsContextUtil() {} - static boolean hasCertProviderInstance(CommonTlsContext commonTlsContext) { + public static boolean hasCertProviderInstance(CommonTlsContext commonTlsContext) { if (commonTlsContext == null) { return false; } - return hasIdentityCertificateProviderInstance(commonTlsContext) - || hasCertProviderValidationContext(commonTlsContext); - } - - private static boolean hasCertProviderValidationContext(CommonTlsContext commonTlsContext) { - if (commonTlsContext.hasCombinedValidationContext()) { - CombinedCertificateValidationContext combinedCertificateValidationContext = - commonTlsContext.getCombinedValidationContext(); - return combinedCertificateValidationContext.hasValidationContextCertificateProviderInstance(); - } - return hasValidationProviderInstance(commonTlsContext); - } - - private static boolean hasIdentityCertificateProviderInstance(CommonTlsContext commonTlsContext) { + @SuppressWarnings("deprecation") + boolean hasDeprecatedField = commonTlsContext.hasTlsCertificateCertificateProviderInstance(); return commonTlsContext.hasTlsCertificateProviderInstance() - || commonTlsContext.hasTlsCertificateCertificateProviderInstance(); + || hasDeprecatedField + || hasValidationProviderInstance(commonTlsContext); } private static boolean hasValidationProviderInstance(CommonTlsContext commonTlsContext) { @@ -52,7 +40,19 @@ private static boolean hasValidationProviderInstance(CommonTlsContext commonTlsC .hasCaCertificateProviderInstance()) { return true; } - return commonTlsContext.hasValidationContextCertificateProviderInstance(); + if (commonTlsContext.hasCombinedValidationContext()) { + CommonTlsContext.CombinedCertificateValidationContext combined = + commonTlsContext.getCombinedValidationContext(); + if (combined.hasDefaultValidationContext() + && combined.getDefaultValidationContext().hasCaCertificateProviderInstance()) { + return true; + } + // Check deprecated field (field 4) in 
CombinedValidationContext + @SuppressWarnings("deprecation") + boolean hasDeprecatedField = combined.hasValidationContextCertificateProviderInstance(); + return hasDeprecatedField; + } + return false; } /** @@ -65,4 +65,15 @@ public static CommonTlsContext.CertificateProviderInstance convert( .setInstanceName(pluginInstance.getInstanceName()) .setCertificateName(pluginInstance.getCertificateName()).build(); } + + public static boolean isUsingSystemRootCerts(CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasCombinedValidationContext()) { + return commonTlsContext.getCombinedValidationContext().getDefaultValidationContext() + .hasSystemRootCerts(); + } + if (commonTlsContext.hasValidationContext()) { + return commonTlsContext.getValidationContext().hasSystemRootCerts(); + } + return false; + } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java index 6bf66d022ff..59e114a89ff 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java @@ -30,9 +30,11 @@ import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; import javax.annotation.Nullable; +import javax.net.ssl.X509TrustManager; /** Base class for dynamic {@link SslContextProvider}s. 
*/ @Internal @@ -40,7 +42,9 @@ public abstract class DynamicSslContextProvider extends SslContextProvider { protected final List pendingCallbacks = new ArrayList<>(); @Nullable protected final CertificateValidationContext staticCertificateValidationContext; - @Nullable protected SslContext sslContext; + @Nullable protected AbstractMap.SimpleImmutableEntry + sslContextAndTrustManager; + protected boolean autoSniSanValidationDoesNotApply; protected DynamicSslContextProvider( BaseTlsContext tlsContext, CertificateValidationContext staticCertValidationContext) { @@ -49,15 +53,21 @@ protected DynamicSslContextProvider( } @Nullable - public SslContext getSslContext() { - return sslContext; + public AbstractMap.SimpleImmutableEntry + getSslContextAndTrustManager() { + return sslContextAndTrustManager; } protected abstract CertificateValidationContext generateCertificateValidationContext(); + public void setAutoSniSanValidationDoesNotApply() { + autoSniSanValidationDoesNotApply = true; + } + /** Gets a server or client side SslContextBuilder. */ - protected abstract SslContextBuilder getSslContextBuilder( - CertificateValidationContext certificateValidationContext) + protected abstract AbstractMap.SimpleImmutableEntry + getSslContextBuilderAndTrustManager( + CertificateValidationContext certificateValidationContext) throws CertificateException, IOException, CertStoreException; // this gets called only when requested secrets are ready... 
@@ -65,7 +75,8 @@ protected final void updateSslContext() { try { CertificateValidationContext localCertValidationContext = generateCertificateValidationContext(); - SslContextBuilder sslContextBuilder = getSslContextBuilder(localCertValidationContext); + AbstractMap.SimpleImmutableEntry sslContextBuilderAndTm = + getSslContextBuilderAndTrustManager(localCertValidationContext); CommonTlsContext commonTlsContext = getCommonTlsContext(); if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) { List alpnList = commonTlsContext.getAlpnProtocolsList(); @@ -75,16 +86,18 @@ protected final void updateSslContext() { ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, alpnList); - sslContextBuilder.applicationProtocolConfig(apn); + sslContextBuilderAndTm.getKey().applicationProtocolConfig(apn); } List pendingCallbacksCopy; - SslContext sslContextCopy; + AbstractMap.SimpleImmutableEntry + sslContextAndExtendedX09TrustManagerCopy; synchronized (pendingCallbacks) { - sslContext = sslContextBuilder.build(); - sslContextCopy = sslContext; + sslContextAndTrustManager = new AbstractMap.SimpleImmutableEntry<>( + sslContextBuilderAndTm.getKey().build(), sslContextBuilderAndTm.getValue()); + sslContextAndExtendedX09TrustManagerCopy = sslContextAndTrustManager; pendingCallbacksCopy = clonePendingCallbacksAndClear(); } - makePendingCallbacks(sslContextCopy, pendingCallbacksCopy); + makePendingCallbacks(sslContextAndExtendedX09TrustManagerCopy, pendingCallbacksCopy); } catch (Exception e) { onError(Status.fromThrowable(e)); throw new RuntimeException(e); @@ -92,12 +105,13 @@ protected final void updateSslContext() { } protected final void callPerformCallback( - Callback callback, final SslContext sslContextCopy) { + Callback callback, + final AbstractMap.SimpleImmutableEntry sslContextAndTmCopy) { performCallback( new SslContextGetter() { @Override - public SslContext get() { - return 
sslContextCopy; + public AbstractMap.SimpleImmutableEntry get() { + return sslContextAndTmCopy; } }, callback @@ -108,10 +122,10 @@ public SslContext get() { public final void addCallback(Callback callback) { checkNotNull(callback, "callback"); // if there is a computed sslContext just send it - SslContext sslContextCopy = null; + AbstractMap.SimpleImmutableEntry sslContextCopy = null; synchronized (pendingCallbacks) { - if (sslContext != null) { - sslContextCopy = sslContext; + if (sslContextAndTrustManager != null) { + sslContextCopy = sslContextAndTrustManager; } else { pendingCallbacks.add(callback); } @@ -122,9 +136,11 @@ public final void addCallback(Callback callback) { } private final void makePendingCallbacks( - SslContext sslContextCopy, List pendingCallbacksCopy) { + AbstractMap.SimpleImmutableEntry + sslContextAndExtendedX509TrustManagerCopy, + List pendingCallbacksCopy) { for (Callback callback : pendingCallbacksCopy) { - callPerformCallback(callback, sslContextCopy); + callPerformCallback(callback, sslContextAndExtendedX509TrustManagerCopy); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/ReferenceCountingMap.java b/xds/src/main/java/io/grpc/xds/internal/security/ReferenceCountingMap.java index b7f56492fa5..08b8f6a325b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/ReferenceCountingMap.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/ReferenceCountingMap.java @@ -20,9 +20,9 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.errorprone.annotations.CheckReturnValue; import java.util.HashMap; import java.util.Map; -import javax.annotation.CheckReturnValue; import javax.annotation.concurrent.ThreadSafe; /** diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index 00659e53de1..a93299de11c 100644 
--- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -19,7 +19,9 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; import io.grpc.Attributes; +import io.grpc.Grpc; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; @@ -28,7 +30,10 @@ import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.ProtocolNegotiationEvent; -import io.grpc.xds.InternalXdsAttributes; +import io.grpc.xds.EnvoyServerProtoData; +import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.internal.XdsInternalAttributes; +import io.grpc.xds.internal.security.trust.CertificateUtils; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; @@ -36,12 +41,14 @@ import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; import java.security.cert.CertStoreException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; +import javax.net.ssl.X509TrustManager; /** * Provides client and server side gRPC {@link ProtocolNegotiator}s to provide the SSL @@ -60,8 +67,14 @@ private SecurityProtocolNegotiators() { private static final AsciiString SCHEME = AsciiString.of("http"); public static final Attributes.Key - ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER = - Attributes.Key.create("io.grpc.xds.internal.security.server.sslContextProviderSupplier"); + ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER = + 
Attributes.Key.create("io.grpc.xds.internal.security.server.sslContextProviderSupplier"); + + /** Attribute key for SslContextProviderSupplier (used from client) for a subchannel. */ + @Grpc.TransportAttr + public static final Attributes.Key + ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER = + Attributes.Key.create("io.grpc.xds.internal.security.SslContextProviderSupplier"); /** * Returns a {@link InternalProtocolNegotiator.ClientFactory}. @@ -130,14 +143,14 @@ public AsciiString scheme() { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { // check if SslContextProviderSupplier was passed via attributes SslContextProviderSupplier localSslContextProviderSupplier = - grpcHandler.getEagAttributes().get( - InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); + grpcHandler.getEagAttributes().get(ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); if (localSslContextProviderSupplier == null) { checkNotNull( fallbackProtocolNegotiator, "No TLS config and no fallbackProtocolNegotiator!"); return fallbackProtocolNegotiator.newHandler(grpcHandler); } - return new ClientSecurityHandler(grpcHandler, localSslContextProviderSupplier); + return new ClientSecurityHandler(grpcHandler, localSslContextProviderSupplier, + grpcHandler.getEagAttributes().get(XdsInternalAttributes.ATTR_ADDRESS_NAME)); } @Override @@ -180,10 +193,13 @@ static final class ClientSecurityHandler extends InternalProtocolNegotiators.ProtocolNegotiationHandler { private final GrpcHttp2ConnectionHandler grpcHandler; private final SslContextProviderSupplier sslContextProviderSupplier; + private final String sni; + private final boolean autoSniSanValidationDoesNotApply; ClientSecurityHandler( GrpcHttp2ConnectionHandler grpcHandler, - SslContextProviderSupplier sslContextProviderSupplier) { + SslContextProviderSupplier sslContextProviderSupplier, + String endpointHostname) { super( // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' // handler but we don't have a next handler 
_yet_. So we "disable" superclass's behavior @@ -197,6 +213,26 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { checkNotNull(grpcHandler, "grpcHandler"); this.grpcHandler = grpcHandler; this.sslContextProviderSupplier = sslContextProviderSupplier; + EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); + UpstreamTlsContext upstreamTlsContext = ((UpstreamTlsContext) tlsContext); + + String sniToUse = upstreamTlsContext.getAutoHostSni() + && !Strings.isNullOrEmpty(endpointHostname) + ? endpointHostname : upstreamTlsContext.getSni(); + if (sniToUse.isEmpty()) { + if (CertificateUtils.useChannelAuthorityIfNoSniApplicable) { + sniToUse = grpcHandler.getAuthority(); + } + autoSniSanValidationDoesNotApply = true; + } else { + autoSniSanValidationDoesNotApply = false; + } + sni = sniToUse; + } + + @VisibleForTesting + String getSni() { + return sni; } @Override @@ -208,7 +244,8 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { if (ctx.isRemoved()) { return; } @@ -217,7 +254,9 @@ public void updateSslContext(SslContext sslContext) { "ClientSecurityHandler.updateSslContext authority={0}, ctx.name={1}", new Object[]{grpcHandler.getAuthority(), ctx.name()}); ChannelHandler handler = - InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); + InternalProtocolNegotiators.tls( + sslContextAndTm.getKey(), sni, sslContextAndTm.getValue()) + .newHandler(grpcHandler); // Delegate rest of handshake to TLS handler ctx.pipeline().addAfter(ctx.name(), null, handler); @@ -229,8 +268,8 @@ public void updateSslContext(SslContext sslContext) { public void onException(Throwable throwable) { ctx.fireExceptionCaught(throwable); } - } - ); + }, + 
autoSniSanValidationDoesNotApply); } @Override @@ -351,9 +390,10 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSslContext(SslContext sslContext) { - ChannelHandler handler = - InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + ChannelHandler handler = InternalProtocolNegotiators.serverTls( + sslContextAndTm.getKey()).newHandler(grpcHandler); // Delegate rest of handshake to TLS handler if (!ctx.isRemoved()) { @@ -367,8 +407,8 @@ public void updateSslContext(SslContext sslContext) { public void onException(Throwable throwable) { ctx.fireExceptionCaught(throwable); } - } - ); + }, + false); } } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index a0c4ed37dfb..a5d14f72dc5 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -32,7 +32,9 @@ import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; +import java.util.AbstractMap; import java.util.concurrent.Executor; +import javax.net.ssl.X509TrustManager; /** * A SslContextProvider is a "container" or provider of SslContext. This is used by gRPC-xds to @@ -57,7 +59,8 @@ protected Callback(Executor executor) { } /** Informs callee of new/updated SslContext. */ - @VisibleForTesting public abstract void updateSslContext(SslContext sslContext); + @VisibleForTesting public abstract void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContext); /** Informs callee of an exception that was generated. 
*/ @VisibleForTesting protected abstract void onException(Throwable throwable); @@ -119,8 +122,9 @@ protected final void performCallback( @Override public void run() { try { - SslContext sslContext = sslContextGetter.get(); - callback.updateSslContext(sslContext); + AbstractMap.SimpleImmutableEntry sslContextAndTm = + sslContextGetter.get(); + callback.updateSslContextAndExtendedX509TrustManager(sslContextAndTm); } catch (Throwable e) { callback.onException(e); } @@ -130,6 +134,6 @@ public void run() { /** Allows implementations to compute or get SslContext. */ protected interface SslContextGetter { - SslContext get() throws Exception; + AbstractMap.SimpleImmutableEntry get() throws Exception; } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 5f629273179..e5960dd95e8 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -25,7 +25,9 @@ import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; +import java.util.AbstractMap; import java.util.Objects; +import javax.net.ssl.X509TrustManager; /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} @@ -52,12 +54,16 @@ public BaseTlsContext getTlsContext() { } /** Updates SslContext via the passed callback. 
*/ - public synchronized void updateSslContext(final SslContextProvider.Callback callback) { + public synchronized void updateSslContext( + final SslContextProvider.Callback callback, boolean autoSniSanValidationDoesNotApply) { checkNotNull(callback, "callback"); try { if (!shutdown) { if (sslContextProvider == null) { sslContextProvider = getSslContextProvider(); + if (tlsContext instanceof UpstreamTlsContext && autoSniSanValidationDoesNotApply) { + ((DynamicSslContextProvider) sslContextProvider).setAutoSniSanValidationDoesNotApply(); + } } } // we want to increment the ref-count so call findOrCreate again... @@ -66,8 +72,9 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call new SslContextProvider.Callback(callback.getExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { - callback.updateSslContext(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + callback.updateSslContextAndExtendedX509TrustManager(sslContextAndTm); releaseSslContextProvider(toRelease); } @@ -98,7 +105,8 @@ private void releaseSslContextProvider(SslContextProvider toRelease) { private SslContextProvider getSslContextProvider() { return tlsContext instanceof UpstreamTlsContext ? 
tlsContextManager.findOrCreateClientSslContextProvider((UpstreamTlsContext) tlsContext) - : tlsContextManager.findOrCreateServerSslContextProvider((DownstreamTlsContext) tlsContext); + : tlsContextManager.findOrCreateServerSslContextProvider( + (DownstreamTlsContext) tlsContext); } @VisibleForTesting public boolean isShutdown() { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java b/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java index 34a8863c52b..f56524d50b7 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java @@ -71,8 +71,6 @@ public SslContextProvider findOrCreateServerSslContextProvider( public SslContextProvider findOrCreateClientSslContextProvider( UpstreamTlsContext upstreamTlsContext) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); - CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder(); - upstreamTlsContext = new UpstreamTlsContext(builder.build()); return mapForClients.get(upstreamTlsContext); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index d4080101c1a..b4b72ae11c6 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -16,8 +16,6 @@ package io.grpc.xds.internal.security.certprovider; -import static com.google.common.base.Preconditions.checkNotNull; - import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; @@ -28,8 +26,11 @@ 
import io.netty.handler.ssl.SslContextBuilder; import java.security.cert.CertStoreException; import java.security.cert.X509Certificate; +import java.util.AbstractMap; +import java.util.Arrays; import java.util.Map; import javax.annotation.Nullable; +import javax.net.ssl.X509TrustManager; /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { @@ -46,26 +47,55 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP node, certProviders, certInstance, - checkNotNull(rootCertInstance, "Client SSL requires rootCertInstance"), + rootCertInstance, staticCertValidationContext, upstreamTlsContext, certificateProviderStore); } @Override - protected final SslContextBuilder getSslContextBuilder( - CertificateValidationContext certificateValidationContextdationContext) - throws CertStoreException { - SslContextBuilder sslContextBuilder = - GrpcSslContexts.forClient() - .trustManager( - new XdsTrustManagerFactory( - savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContextdationContext)); + protected final AbstractMap.SimpleImmutableEntry + getSslContextBuilderAndTrustManager( + CertificateValidationContext certificateValidationContext) + throws CertStoreException { + SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); + if (savedSpiffeTrustMap != null) { + sslContextBuilder = sslContextBuilder.trustManager( + new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContext, + autoSniSanValidationDoesNotApply + ? false : ((UpstreamTlsContext) tlsContext).getAutoSniSanValidation())); + } else if (savedTrustedRoots != null) { + sslContextBuilder = sslContextBuilder.trustManager( + new XdsTrustManagerFactory( + savedTrustedRoots.toArray(new X509Certificate[0]), + certificateValidationContext, + autoSniSanValidationDoesNotApply + ? 
false : ((UpstreamTlsContext) tlsContext).getAutoSniSanValidation())); + } else { + // Should be impossible because of the check in CertProviderClientSslContextProviderFactory + throw new IllegalStateException("There must be trusted roots or a SPIFFE trust map"); + } + XdsTrustManagerFactory trustManagerFactory; + if (savedSpiffeTrustMap != null) { + trustManagerFactory = new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContext, + ((UpstreamTlsContext) tlsContext).getAutoSniSanValidation()); + sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); + } else { + trustManagerFactory = new XdsTrustManagerFactory( + savedTrustedRoots.toArray(new X509Certificate[0]), + certificateValidationContext, + ((UpstreamTlsContext) tlsContext).getAutoSniSanValidation()); + sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); + } if (isMtls()) { sslContextBuilder.keyManager(savedKey, savedCertChain); } - return sslContextBuilder; + return new AbstractMap.SimpleImmutableEntry<>(sslContextBuilder, + io.grpc.internal.CertificateUtils.getX509ExtendedTrustManager( + Arrays.asList(trustManagerFactory.getTrustManagers()))); } - } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java index 21782741c2c..6205c1c3a63 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java @@ -25,6 +25,7 @@ import io.grpc.Internal; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; +import io.grpc.xds.internal.security.CommonTlsContextUtil; import io.grpc.xds.internal.security.SslContextProvider; import java.util.Map; 
import javax.annotation.Nullable; @@ -64,13 +65,17 @@ public SslContextProvider getProvider( = CertProviderSslContextProvider.getRootCertProviderInstance(commonTlsContext); CommonTlsContext.CertificateProviderInstance certInstance = CertProviderSslContextProvider.getCertProviderInstance(commonTlsContext); - return new CertProviderClientSslContextProvider( - node, - certProviders, - certInstance, - rootCertInstance, - staticCertValidationContext, - upstreamTlsContext, - certificateProviderStore); + if (CommonTlsContextUtil.hasCertProviderInstance(upstreamTlsContext.getCommonTlsContext()) + || CommonTlsContextUtil.isUsingSystemRootCerts(upstreamTlsContext.getCommonTlsContext())) { + return new CertProviderClientSslContextProvider( + node, + certProviders, + certInstance, + rootCertInstance, + staticCertValidationContext, + upstreamTlsContext, + certificateProviderStore); + } + throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!"); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java index e43452a53e1..3712b948142 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java @@ -30,8 +30,10 @@ import java.security.cert.CertStoreException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.AbstractMap; import java.util.Map; import javax.annotation.Nullable; +import javax.net.ssl.X509TrustManager; /** A server SslContext provider using CertificateProviderInstance to fetch secrets. 
*/ final class CertProviderServerSslContextProvider extends CertProviderSslContextProvider { @@ -55,19 +57,25 @@ final class CertProviderServerSslContextProvider extends CertProviderSslContextP } @Override - protected final SslContextBuilder getSslContextBuilder( - CertificateValidationContext certificateValidationContextdationContext) - throws CertStoreException, CertificateException, IOException { + protected final AbstractMap.SimpleImmutableEntry + getSslContextBuilderAndTrustManager( + CertificateValidationContext certificateValidationContextdationContext) + throws CertStoreException, CertificateException, IOException { SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(savedKey, savedCertChain); - setClientAuthValues( - sslContextBuilder, - isMtls() - ? new XdsTrustManagerFactory( - savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContextdationContext) - : null); + XdsTrustManagerFactory trustManagerFactory = null; + if (isMtls() && savedSpiffeTrustMap != null) { + trustManagerFactory = new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContextdationContext, false); + } else if (isMtls()) { + trustManagerFactory = new XdsTrustManagerFactory( + savedTrustedRoots.toArray(new X509Certificate[0]), + certificateValidationContextdationContext, false); + } + setClientAuthValues(sslContextBuilder, trustManagerFactory); sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder); - return sslContextBuilder; + // TrustManager in the below return value is not used on the server side, so setting it to null + return new AbstractMap.SimpleImmutableEntry<>(sslContextBuilder, null); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 6570c619913..cb99ca6ad95 100644 --- 
a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -24,6 +24,7 @@ import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.internal.security.CommonTlsContextUtil; import io.grpc.xds.internal.security.DynamicSslContextProvider; +import java.io.Closeable; import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.List; @@ -34,13 +35,15 @@ abstract class CertProviderSslContextProvider extends DynamicSslContextProvider implements CertificateProvider.Watcher { - @Nullable private final CertificateProviderStore.Handle certHandle; - @Nullable private final CertificateProviderStore.Handle rootCertHandle; + @Nullable private final NoExceptionCloseable certHandle; + @Nullable private final NoExceptionCloseable rootCertHandle; @Nullable private final CertificateProviderInstance certInstance; - @Nullable private final CertificateProviderInstance rootCertInstance; + @Nullable protected final CertificateProviderInstance rootCertInstance; @Nullable protected PrivateKey savedKey; @Nullable protected List savedCertChain; @Nullable protected List savedTrustedRoots; + @Nullable protected Map> savedSpiffeTrustMap; + private final boolean isUsingSystemRootCerts; protected CertProviderSslContextProvider( Node node, @@ -53,24 +56,33 @@ protected CertProviderSslContextProvider( super(tlsContext, staticCertValidationContext); this.certInstance = certInstance; this.rootCertInstance = rootCertInstance; - String certInstanceName = null; - if (certInstance != null && certInstance.isInitialized()) { - certInstanceName = certInstance.getInstanceName(); + this.isUsingSystemRootCerts = rootCertInstance == null + && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()); + boolean createCertInstance = certInstance != null && certInstance.isInitialized(); + boolean 
createRootCertInstance = rootCertInstance != null && rootCertInstance.isInitialized(); + boolean sharedCertInstance = createCertInstance && createRootCertInstance + && rootCertInstance.getInstanceName().equals(certInstance.getInstanceName()); + if (createCertInstance) { CertificateProviderInfo certProviderInstanceConfig = - getCertProviderConfig(certProviders, certInstanceName); + getCertProviderConfig(certProviders, certInstance.getInstanceName()); + CertificateProvider.Watcher watcher = this; + if (!sharedCertInstance && !isUsingSystemRootCerts) { + watcher = new IgnoreUpdatesWatcher(watcher, /* ignoreRootCertUpdates= */ true); + } + // TODO: Previously we'd hang if certProviderInstanceConfig were null or + // certInstance.isInitialized() == false. Now we'll proceed. Those should be errors, or are + // they impossible and should be assertions? certHandle = certProviderInstanceConfig == null ? null : certificateProviderStore.createOrGetProvider( certInstance.getCertificateName(), certProviderInstanceConfig.pluginName(), certProviderInstanceConfig.config(), - this, - true); + watcher, + true)::close; } else { certHandle = null; } - if (rootCertInstance != null - && rootCertInstance.isInitialized() - && !rootCertInstance.getInstanceName().equals(certInstanceName)) { + if (createRootCertInstance && !sharedCertInstance) { CertificateProviderInfo certProviderInstanceConfig = getCertProviderConfig(certProviders, rootCertInstance.getInstanceName()); rootCertHandle = certProviderInstanceConfig == null ? 
null @@ -78,8 +90,13 @@ protected CertProviderSslContextProvider( rootCertInstance.getCertificateName(), certProviderInstanceConfig.pluginName(), certProviderInstanceConfig.config(), - this, - true); + new IgnoreUpdatesWatcher(this, /* ignoreRootCertUpdates= */ false), + false)::close; + } else if (rootCertInstance == null + && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext())) { + SystemRootCertificateProvider systemRootProvider = new SystemRootCertificateProvider(this); + systemRootProvider.start(); + rootCertHandle = systemRootProvider::close; } else { rootCertHandle = null; } @@ -95,10 +112,14 @@ protected static CertificateProviderInstance getCertProviderInstance( CommonTlsContext commonTlsContext) { if (commonTlsContext.hasTlsCertificateProviderInstance()) { return CommonTlsContextUtil.convert(commonTlsContext.getTlsCertificateProviderInstance()); - } else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { - return commonTlsContext.getTlsCertificateCertificateProviderInstance(); } - return null; + // Fall back to deprecated field for backward compatibility with Istio + @SuppressWarnings("deprecation") + CertificateProviderInstance deprecatedInstance = + commonTlsContext.hasTlsCertificateCertificateProviderInstance() + ? 
commonTlsContext.getTlsCertificateCertificateProviderInstance() + : null; + return deprecatedInstance; } @Nullable @@ -124,15 +145,6 @@ protected static CommonTlsContext.CertificateProviderInstance getRootCertProvide if (certValidationContext != null && certValidationContext.hasCaCertificateProviderInstance()) { return CommonTlsContextUtil.convert(certValidationContext.getCaCertificateProviderInstance()); } - if (commonTlsContext.hasCombinedValidationContext()) { - CommonTlsContext.CombinedCertificateValidationContext combinedValidationContext = - commonTlsContext.getCombinedValidationContext(); - if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) { - return combinedValidationContext.getValidationContextCertificateProviderInstance(); - } - } else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { - return commonTlsContext.getValidationContextCertificateProviderInstance(); - } return null; } @@ -149,18 +161,24 @@ public final void updateTrustedRoots(List trustedRoots) { updateSslContextWhenReady(); } + @Override + public final void updateSpiffeTrustMap(Map> spiffeTrustMap) { + savedSpiffeTrustMap = spiffeTrustMap; + updateSslContextWhenReady(); + } + private void updateSslContextWhenReady() { if (isMtls()) { - if (savedKey != null && savedTrustedRoots != null) { + if (savedKey != null && (savedTrustedRoots != null || savedSpiffeTrustMap != null)) { updateSslContext(); clearKeysAndCerts(); } - } else if (isClientSideTls()) { - if (savedTrustedRoots != null) { + } else if (isRegularTlsAndClientSide()) { + if (savedTrustedRoots != null || savedSpiffeTrustMap != null) { updateSslContext(); clearKeysAndCerts(); } - } else if (isServerSideTls()) { + } else if (isRegularTlsAndServerSide()) { if (savedKey != null) { updateSslContext(); clearKeysAndCerts(); @@ -170,19 +188,22 @@ private void updateSslContextWhenReady() { private void clearKeysAndCerts() { savedKey = null; - savedTrustedRoots = null; + if 
(!isUsingSystemRootCerts) { + savedTrustedRoots = null; + savedSpiffeTrustMap = null; + } savedCertChain = null; } protected final boolean isMtls() { - return certInstance != null && rootCertInstance != null; + return certInstance != null && (rootCertInstance != null || isUsingSystemRootCerts); } - protected final boolean isClientSideTls() { - return rootCertInstance != null && certInstance == null; + protected final boolean isRegularTlsAndClientSide() { + return (rootCertInstance != null || isUsingSystemRootCerts) && certInstance == null; } - protected final boolean isServerSideTls() { + protected final boolean isRegularTlsAndServerSide() { return certInstance != null && rootCertInstance == null; } @@ -200,4 +221,9 @@ public final void close() { rootCertHandle.close(); } } + + interface NoExceptionCloseable extends Closeable { + @Override + void close(); + } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProvider.java index a0d5d0fc69f..009bb7bf566 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProvider.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -45,6 +46,8 @@ public interface Watcher { void updateTrustedRoots(List trustedRoots); + void updateSpiffeTrustMap(Map> spiffeTrustMap); + void onError(Status errorStatus); } @@ -53,6 +56,7 @@ public static final class DistributorWatcher implements Watcher { private PrivateKey privateKey; private List certChain; private List trustedRoots; + private Map> spiffeTrustMap; @VisibleForTesting final Set downstreamWatchers = new HashSet<>(); @@ -65,6 +69,9 @@ synchronized void addWatcher(Watcher watcher) { if (trustedRoots != null) { 
sendLastTrustedRootsUpdate(watcher); } + if (spiffeTrustMap != null) { + sendLastSpiffeTrustMapUpdate(watcher); + } } synchronized void removeWatcher(Watcher watcher) { @@ -83,6 +90,10 @@ private void sendLastTrustedRootsUpdate(Watcher watcher) { watcher.updateTrustedRoots(trustedRoots); } + private void sendLastSpiffeTrustMapUpdate(Watcher watcher) { + watcher.updateSpiffeTrustMap(spiffeTrustMap); + } + @Override public synchronized void updateCertificate(PrivateKey key, List certChain) { checkNotNull(key, "key"); @@ -103,6 +114,14 @@ public synchronized void updateTrustedRoots(List trustedRoots) } } + @Override + public void updateSpiffeTrustMap(Map> spiffeTrustMap) { + this.spiffeTrustMap = spiffeTrustMap; + for (Watcher watcher : downstreamWatchers) { + sendLastSpiffeTrustMapUpdate(watcher); + } + } + @Override public synchronized void onError(Status errorStatus) { for (Watcher watcher : downstreamWatchers) { @@ -147,7 +166,7 @@ protected CertificateProvider(DistributorWatcher watcher, boolean notifyCertUpda @Override public abstract void close(); - /** Starts the cert refresh and watcher update cycle. */ + /** Starts the async cert refresh and watcher update cycle. 
*/ public abstract void start(); private final DistributorWatcher watcher; diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java index dd945ce850e..304124cc7f2 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java @@ -16,10 +16,12 @@ package io.grpc.xds.internal.security.certprovider; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import io.grpc.Status; +import io.grpc.internal.SpiffeUtil; import io.grpc.internal.TimeProvider; import io.grpc.xds.internal.security.trust.CertificateUtils; import java.io.ByteArrayInputStream; @@ -30,6 +32,7 @@ import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.Arrays; +import java.util.HashMap; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -47,11 +50,13 @@ final class FileWatcherCertificateProvider extends CertificateProvider implement private final Path certFile; private final Path keyFile; private final Path trustFile; + private final Path spiffeTrustMapFile; private final long refreshIntervalInSeconds; @VisibleForTesting ScheduledFuture scheduledFuture; private FileTime lastModifiedTimeCert; private FileTime lastModifiedTimeKey; private FileTime lastModifiedTimeRoot; + private FileTime lastModifiedTimespiffeTrustMap; private boolean shutdown; FileWatcherCertificateProvider( @@ -60,6 +65,7 @@ final class FileWatcherCertificateProvider extends CertificateProvider implement String certFile, String keyFile, String trustFile, + String spiffeTrustMapFile, 
long refreshIntervalInSeconds, ScheduledExecutorService scheduledExecutorService, TimeProvider timeProvider) { @@ -69,7 +75,15 @@ final class FileWatcherCertificateProvider extends CertificateProvider implement this.timeProvider = checkNotNull(timeProvider, "timeProvider"); this.certFile = Paths.get(checkNotNull(certFile, "certFile")); this.keyFile = Paths.get(checkNotNull(keyFile, "keyFile")); - this.trustFile = Paths.get(checkNotNull(trustFile, "trustFile")); + checkArgument((trustFile != null || spiffeTrustMapFile != null), + "either trustFile or spiffeTrustMapFile must be present"); + if (spiffeTrustMapFile != null) { + this.spiffeTrustMapFile = Paths.get(spiffeTrustMapFile); + this.trustFile = null; + } else { + this.spiffeTrustMapFile = null; + this.trustFile = Paths.get(trustFile); + } this.refreshIntervalInSeconds = refreshIntervalInSeconds; } @@ -107,39 +121,48 @@ void checkAndReloadCertificates() { byte[] keyFileContents = Files.readAllBytes(keyFile); FileTime currentCertTime2 = Files.getLastModifiedTime(certFile); FileTime currentKeyTime2 = Files.getLastModifiedTime(keyFile); - if (!currentCertTime2.equals(currentCertTime)) { - return; - } - if (!currentKeyTime2.equals(currentKeyTime)) { - return; - } - try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents); - ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) { - PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream); - X509Certificate[] certs = CertificateUtils.toX509Certificates(certStream); - getWatcher().updateCertificate(privateKey, Arrays.asList(certs)); + if (currentCertTime2.equals(currentCertTime) && currentKeyTime2.equals(currentKeyTime)) { + try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents); + ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) { + PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream); + X509Certificate[] certs = 
CertificateUtils.toX509Certificates(certStream); + getWatcher().updateCertificate(privateKey, Arrays.asList(certs)); + } + lastModifiedTimeCert = currentCertTime; + lastModifiedTimeKey = currentKeyTime; } - lastModifiedTimeCert = currentCertTime; - lastModifiedTimeKey = currentKeyTime; } } catch (Throwable t) { generateErrorIfCurrentCertExpired(t); } try { - FileTime currentRootTime = Files.getLastModifiedTime(trustFile); - if (currentRootTime.equals(lastModifiedTimeRoot)) { - return; - } - byte[] rootFileContents = Files.readAllBytes(trustFile); - FileTime currentRootTime2 = Files.getLastModifiedTime(trustFile); - if (!currentRootTime2.equals(currentRootTime)) { - return; + if (spiffeTrustMapFile != null) { + FileTime currentSpiffeTime = Files.getLastModifiedTime(spiffeTrustMapFile); + if (!currentSpiffeTime.equals(lastModifiedTimespiffeTrustMap)) { + SpiffeUtil.SpiffeBundle trustBundle = SpiffeUtil + .loadTrustBundleFromFile(spiffeTrustMapFile.toString()); + getWatcher().updateSpiffeTrustMap(new HashMap<>(trustBundle.getBundleMap())); + lastModifiedTimespiffeTrustMap = currentSpiffeTime; + } } - try (ByteArrayInputStream rootStream = new ByteArrayInputStream(rootFileContents)) { - X509Certificate[] caCerts = CertificateUtils.toX509Certificates(rootStream); - getWatcher().updateTrustedRoots(Arrays.asList(caCerts)); + } catch (Throwable t) { + getWatcher().onError(Status.fromThrowable(t)); + } + try { + if (trustFile != null) { + FileTime currentRootTime = Files.getLastModifiedTime(trustFile); + if (!currentRootTime.equals(lastModifiedTimeRoot)) { + byte[] rootFileContents = Files.readAllBytes(trustFile); + FileTime currentRootTime2 = Files.getLastModifiedTime(trustFile); + if (currentRootTime2.equals(currentRootTime)) { + try (ByteArrayInputStream rootStream = new ByteArrayInputStream(rootFileContents)) { + X509Certificate[] caCerts = CertificateUtils.toX509Certificates(rootStream); + getWatcher().updateTrustedRoots(Arrays.asList(caCerts)); + } + 
lastModifiedTimeRoot = currentRootTime; + } + } } - lastModifiedTimeRoot = currentRootTime; } catch (Throwable t) { getWatcher().onError(Status.fromThrowable(t)); } @@ -195,6 +218,7 @@ FileWatcherCertificateProvider create( String certFile, String keyFile, String trustFile, + String spiffeTrustMapFile, long refreshIntervalInSeconds, ScheduledExecutorService scheduledExecutorService, TimeProvider timeProvider) { @@ -204,6 +228,7 @@ FileWatcherCertificateProvider create( certFile, keyFile, trustFile, + spiffeTrustMapFile, refreshIntervalInSeconds, scheduledExecutorService, timeProvider); @@ -220,6 +245,7 @@ abstract FileWatcherCertificateProvider create( String certFile, String keyFile, String trustFile, + String spiffeTrustMapFile, long refreshIntervalInSeconds, ScheduledExecutorService scheduledExecutorService, TimeProvider timeProvider); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java index c4b140442cb..e4871dc4c84 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java @@ -23,6 +23,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.protobuf.Duration; import com.google.protobuf.util.Durations; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; import io.grpc.internal.TimeProvider; import java.text.ParseException; @@ -33,11 +34,16 @@ /** * Provider of {@link FileWatcherCertificateProvider}s. */ -final class FileWatcherCertificateProviderProvider implements CertificateProviderProvider { +public final class FileWatcherCertificateProviderProvider implements CertificateProviderProvider { + // TODO(lwge): Remove the old env var check once it's confirmed to be unused. 
+ @VisibleForTesting + public static boolean enableSpiffe = GrpcUtil.getFlag("GRPC_EXPERIMENTAL_SPIFFE_TRUST_BUNDLE_MAP", + false) || GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_MTLS_SPIFFE", false); private static final String CERT_FILE_KEY = "certificate_file"; private static final String KEY_FILE_KEY = "private_key_file"; private static final String ROOT_FILE_KEY = "ca_certificate_file"; + private static final String SPIFFE_TRUST_MAP_FILE_KEY = "spiffe_trust_bundle_map_file"; private static final String REFRESH_INTERVAL_KEY = "refresh_interval"; @VisibleForTesting static final long REFRESH_INTERVAL_DEFAULT = 600L; @@ -82,6 +88,7 @@ public CertificateProvider createCertificateProvider( configObj.certFile, configObj.keyFile, configObj.rootFile, + configObj.spiffeTrustMapFile, configObj.refrehInterval, scheduledExecutorServiceFactory.create(), timeProvider); @@ -98,7 +105,20 @@ private static Config validateAndTranslateConfig(Object config) { Config configObj = new Config(); configObj.certFile = checkForNullAndGet(map, CERT_FILE_KEY); configObj.keyFile = checkForNullAndGet(map, KEY_FILE_KEY); - configObj.rootFile = checkForNullAndGet(map, ROOT_FILE_KEY); + if (enableSpiffe) { + if (!map.containsKey(ROOT_FILE_KEY) && !map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY)) { + throw new NullPointerException( + String.format("either '%s' or '%s' is required in the config", + ROOT_FILE_KEY, SPIFFE_TRUST_MAP_FILE_KEY)); + } + if (map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY)) { + configObj.spiffeTrustMapFile = JsonUtil.getString(map, SPIFFE_TRUST_MAP_FILE_KEY); + } else { + configObj.rootFile = JsonUtil.getString(map, ROOT_FILE_KEY); + } + } else { + configObj.rootFile = checkForNullAndGet(map, ROOT_FILE_KEY); + } String refreshIntervalString = JsonUtil.getString(map, REFRESH_INTERVAL_KEY); if (refreshIntervalString != null) { try { @@ -139,6 +159,7 @@ static class Config { String certFile; String keyFile; String rootFile; + String spiffeTrustMapFile; Long refrehInterval; } } diff --git 
a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java new file mode 100644 index 00000000000..cd9d88be41b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java @@ -0,0 +1,68 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.security.certprovider; + +import static java.util.Objects.requireNonNull; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Status; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.Map; + +public final class IgnoreUpdatesWatcher implements CertificateProvider.Watcher { + private final CertificateProvider.Watcher delegate; + private final boolean ignoreRootCertUpdates; + + public IgnoreUpdatesWatcher( + CertificateProvider.Watcher delegate, boolean ignoreRootCertUpdates) { + this.delegate = requireNonNull(delegate, "delegate"); + this.ignoreRootCertUpdates = ignoreRootCertUpdates; + } + + @Override + public void updateCertificate(PrivateKey key, List certChain) { + if (ignoreRootCertUpdates) { + delegate.updateCertificate(key, certChain); + } + } + + @Override + public void updateTrustedRoots(List trustedRoots) { + if (!ignoreRootCertUpdates) { + delegate.updateTrustedRoots(trustedRoots); + } + 
} + + @Override + public void updateSpiffeTrustMap(Map> spiffeTrustMap) { + if (!ignoreRootCertUpdates) { + delegate.updateSpiffeTrustMap(spiffeTrustMap); + } + } + + @Override + public void onError(Status errorStatus) { + delegate.onError(errorStatus); + } + + @VisibleForTesting + public CertificateProvider.Watcher getDelegate() { + return delegate; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java new file mode 100644 index 00000000000..7c60f714e71 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java @@ -0,0 +1,71 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.security.certprovider; + +import io.grpc.Status; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + +/** + * An non-registered provider for CertProviderSslContextProvider to use the same code path for + * system root certs as provider-obtained certs. 
+ */ +final class SystemRootCertificateProvider extends CertificateProvider { + public SystemRootCertificateProvider(CertificateProvider.Watcher watcher) { + super(new DistributorWatcher(), false); + getWatcher().addWatcher(watcher); + } + + @Override + public void start() { + try { + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + + List trustManagers = Arrays.asList(trustManagerFactory.getTrustManagers()); + List rootCerts = trustManagers.stream() + .filter(X509TrustManager.class::isInstance) + .map(X509TrustManager.class::cast) + .map(trustManager -> Arrays.asList(trustManager.getAcceptedIssuers())) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + getWatcher().updateTrustedRoots(rootCerts); + } catch (KeyStoreException | NoSuchAlgorithmException ex) { + getWatcher().onError(Status.UNAVAILABLE + .withDescription("Could not load system root certs") + .withCause(ex)); + } + } + + @Override + public void close() { + // Unnecessary because there's no more callbacks, but do it for good measure + for (Watcher watcher : getWatcher().getDownstreamWatchers()) { + getWatcher().removeWatcher(watcher); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java index 86b6dd95c3e..41a3980c123 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java @@ -16,6 +16,7 @@ package io.grpc.xds.internal.security.trust; +import io.grpc.internal.GrpcUtil; import java.io.BufferedInputStream; import java.io.File; import java.io.FileInputStream; @@ -29,6 +30,9 @@ * Contains certificate utility method(s). 
*/ public final class CertificateUtils { + public static boolean useChannelAuthorityIfNoSniApplicable + = GrpcUtil.getFlag("GRPC_USE_CHANNEL_AUTHORITY_IF_NO_SNI_APPLICABLE", false); + /** * Generates X509Certificate array from a file on disk. * diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java index c9d83902ec2..664c5dd9362 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java @@ -17,6 +17,7 @@ package io.grpc.xds.internal.security.trust; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; @@ -33,6 +34,9 @@ import java.security.cert.CertStoreException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.logging.Level; import java.util.logging.Logger; import javax.net.ssl.ManagerFactoryParameters; @@ -54,26 +58,51 @@ public XdsTrustManagerFactory(CertificateValidationContext certificateValidation this( getTrustedCaFromCertContext(certificateValidationContext), certificateValidationContext, + false, false); } public XdsTrustManagerFactory( - X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext) - throws CertStoreException { - this(certs, staticCertificateValidationContext, true); + X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext, + boolean autoSniSanValidation) throws CertStoreException { + this(certs, staticCertificateValidationContext, true, autoSniSanValidation); + } + + public XdsTrustManagerFactory(Map> 
spiffeTrustMap, + CertificateValidationContext staticCertificateValidationContext, boolean autoSniSanValidation) + throws CertStoreException { + this(spiffeTrustMap, staticCertificateValidationContext, true, autoSniSanValidation); } private XdsTrustManagerFactory( X509Certificate[] certs, CertificateValidationContext certificateValidationContext, - boolean validationContextIsStatic) + boolean validationContextIsStatic, + boolean autoSniSanValidation) + throws CertStoreException { + if (validationContextIsStatic) { + checkArgument( + certificateValidationContext == null || !certificateValidationContext.hasTrustedCa() + || certificateValidationContext.hasSystemRootCerts(), + "only static certificateValidationContext expected"); + } + xdsX509TrustManager = createX509TrustManager( + certs, certificateValidationContext, autoSniSanValidation); + } + + private XdsTrustManagerFactory( + Map> spiffeTrustMap, + CertificateValidationContext certificateValidationContext, + boolean validationContextIsStatic, + boolean autoSniSanValidation) throws CertStoreException { if (validationContextIsStatic) { checkArgument( certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(), "only static certificateValidationContext expected"); + xdsX509TrustManager = createX509TrustManager( + spiffeTrustMap, certificateValidationContext, autoSniSanValidation); } - xdsX509TrustManager = createX509TrustManager(certs, certificateValidationContext); } private static X509Certificate[] getTrustedCaFromCertContext( @@ -99,7 +128,28 @@ private static X509Certificate[] getTrustedCaFromCertContext( @VisibleForTesting static XdsX509TrustManager createX509TrustManager( - X509Certificate[] certs, CertificateValidationContext certContext) throws CertStoreException { + X509Certificate[] certs, CertificateValidationContext certContext, + boolean autoSniSanValidation) + throws CertStoreException { + return new XdsX509TrustManager(certContext, createTrustManager(certs), 
autoSniSanValidation); + } + + @VisibleForTesting + static XdsX509TrustManager createX509TrustManager( + Map> spiffeTrustMapFile, + CertificateValidationContext certContext, boolean autoSniSanValidation) + throws CertStoreException { + checkNotNull(spiffeTrustMapFile, "spiffeTrustMapFile"); + Map delegates = new HashMap<>(); + for (Map.Entry> entry:spiffeTrustMapFile.entrySet()) { + delegates.put(entry.getKey(), createTrustManager( + entry.getValue().toArray(new X509Certificate[0]))); + } + return new XdsX509TrustManager(certContext, delegates, autoSniSanValidation); + } + + private static X509ExtendedTrustManager createTrustManager(X509Certificate[] certs) + throws CertStoreException { TrustManagerFactory tmf = null; try { tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); @@ -131,7 +181,7 @@ static XdsX509TrustManager createX509TrustManager( if (myDelegate == null) { throw new CertStoreException("Native X509 TrustManager not found."); } - return new XdsX509TrustManager(certContext, myDelegate); + return myDelegate; } @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java index 6181d70fa51..01f25dda6c7 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java @@ -19,19 +19,29 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Optional; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableMap; import com.google.re2j.Pattern; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import 
io.grpc.internal.SpiffeUtil; import java.net.Socket; import java.security.cert.CertificateException; import java.security.cert.CertificateParsingException; import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; +import java.util.HashSet; import java.util.List; import java.util.Locale; +import java.util.Map; +import java.util.Set; import javax.annotation.Nullable; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocket; @@ -51,13 +61,28 @@ final class XdsX509TrustManager extends X509ExtendedTrustManager implements X509 private static final int ALT_IPA_NAME = 7; private final X509ExtendedTrustManager delegate; + private final Map spiffeTrustMapDelegates; private final CertificateValidationContext certContext; + private final boolean autoSniSanValidation; XdsX509TrustManager(@Nullable CertificateValidationContext certContext, - X509ExtendedTrustManager delegate) { + X509ExtendedTrustManager delegate, + boolean autoSniSanValidation) { checkNotNull(delegate, "delegate"); this.certContext = certContext; this.delegate = delegate; + this.spiffeTrustMapDelegates = null; + this.autoSniSanValidation = autoSniSanValidation; + } + + XdsX509TrustManager(@Nullable CertificateValidationContext certContext, + Map spiffeTrustMapDelegates, + boolean autoSniSanValidation) { + checkNotNull(spiffeTrustMapDelegates, "spiffeTrustMapDelegates"); + this.spiffeTrustMapDelegates = ImmutableMap.copyOf(spiffeTrustMapDelegates); + this.certContext = certContext; + this.delegate = null; + this.autoSniSanValidation = autoSniSanValidation; } private static boolean verifyDnsNameInPattern( @@ -130,6 +155,9 @@ private static boolean verifyDnsNameExact( if (Strings.isNullOrEmpty(sanToVerifyExact)) { return false; } + if (sanToVerifyExact.contains("*")) { + return verifyDnsNameWildcard(altNameFromCert, sanToVerifyExact, 
ignoreCase); + } return ignoreCase ? sanToVerifyExact.equalsIgnoreCase(altNameFromCert) : sanToVerifyExact.equals(altNameFromCert); @@ -186,11 +214,11 @@ private static void verifySubjectAltNameInLeaf( * This is called from various check*Trusted methods. */ @VisibleForTesting - void verifySubjectAltNameInChain(X509Certificate[] peerCertChain) throws CertificateException { + void verifySubjectAltNameInChain(X509Certificate[] peerCertChain, + List verifyList) throws CertificateException { if (certContext == null) { return; } - List verifyList = certContext.getMatchSubjectAltNamesList(); if (verifyList.isEmpty()) { return; } @@ -202,29 +230,36 @@ void verifySubjectAltNameInChain(X509Certificate[] peerCertChain) throws Certifi } @Override + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException { - delegate.checkClientTrusted(chain, authType, socket); - verifySubjectAltNameInChain(chain); + chooseDelegate(chain).checkClientTrusted(chain, authType, socket); + verifySubjectAltNameInChain(chain, certContext != null + ? certContext.getMatchSubjectAltNamesList() : new ArrayList<>()); } @Override + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine) throws CertificateException { - delegate.checkClientTrusted(chain, authType, sslEngine); - verifySubjectAltNameInChain(chain); + chooseDelegate(chain).checkClientTrusted(chain, authType, sslEngine); + verifySubjectAltNameInChain(chain, certContext != null + ? 
certContext.getMatchSubjectAltNamesList() : new ArrayList<>()); } @Override + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { - delegate.checkClientTrusted(chain, authType); - verifySubjectAltNameInChain(chain); + chooseDelegate(chain).checkClientTrusted(chain, authType); + verifySubjectAltNameInChain(chain, certContext != null + ? certContext.getMatchSubjectAltNamesList() : new ArrayList<>()); } @Override public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException { + List sniMatchers = null; if (socket instanceof SSLSocket) { SSLSocket sslSocket = (SSLSocket) socket; SSLParameters sslParams = sslSocket.getSSLParameters(); @@ -232,32 +267,134 @@ public void checkServerTrusted(X509Certificate[] chain, String authType, Socket sslParams.setEndpointIdentificationAlgorithm(""); sslSocket.setSSLParameters(sslParams); } + sniMatchers = getAutoSniSanMatchers(sslParams); } - delegate.checkServerTrusted(chain, authType, socket); - verifySubjectAltNameInChain(chain); + if (sniMatchers.isEmpty() && certContext != null) { + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names + List sniMatchersTmp = certContext.getMatchSubjectAltNamesList(); + sniMatchers = sniMatchersTmp; + } + chooseDelegate(chain).checkServerTrusted(chain, authType, socket); + verifySubjectAltNameInChain(chain, sniMatchers); } @Override public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine) throws CertificateException { + List sniMatchers = null; SSLParameters sslParams = sslEngine.getSSLParameters(); if (sslParams != null) { sslParams.setEndpointIdentificationAlgorithm(""); sslEngine.setSSLParameters(sslParams); + sniMatchers = getAutoSniSanMatchers(sslParams); + } + if (sniMatchers.isEmpty() && certContext != null) { + 
@SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names + List sniMatchersTmp = certContext.getMatchSubjectAltNamesList(); + sniMatchers = sniMatchersTmp; } - delegate.checkServerTrusted(chain, authType, sslEngine); - verifySubjectAltNameInChain(chain); + chooseDelegate(chain).checkServerTrusted(chain, authType, sslEngine); + verifySubjectAltNameInChain(chain, sniMatchers); } @Override + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { - delegate.checkServerTrusted(chain, authType); - verifySubjectAltNameInChain(chain); + chooseDelegate(chain).checkServerTrusted(chain, authType); + verifySubjectAltNameInChain(chain, certContext != null + ? certContext.getMatchSubjectAltNamesList() : new ArrayList<>()); + } + + private List getAutoSniSanMatchers(SSLParameters sslParams) { + List sniNamesToMatch = new ArrayList<>(); + if (autoSniSanValidation) { + List serverNames = sslParams.getServerNames(); + if (serverNames != null) { + for (SNIServerName serverName : serverNames) { + if (serverName instanceof SNIHostName) { + SNIHostName sniHostName = (SNIHostName) serverName; + String hostName = sniHostName.getAsciiName(); + sniNamesToMatch.add(StringMatcher.newBuilder().setExact(hostName).build()); + } + } + } + } + return sniNamesToMatch; + } + + private X509ExtendedTrustManager chooseDelegate(X509Certificate[] chain) + throws CertificateException { + if (spiffeTrustMapDelegates != null) { + Optional spiffeId = SpiffeUtil.extractSpiffeId(chain); + if (!spiffeId.isPresent()) { + throw new CertificateException("Failed to extract SPIFFE ID from peer leaf certificate"); + } + String trustDomain = spiffeId.get().getTrustDomain(); + if (!spiffeTrustMapDelegates.containsKey(trustDomain)) { + throw new CertificateException(String.format("Spiffe Trust Map doesn't contain trust" + + " domain '%s' from peer leaf 
certificate", trustDomain)); + } + return spiffeTrustMapDelegates.get(trustDomain); + } else { + return delegate; + } } @Override public X509Certificate[] getAcceptedIssuers() { + if (spiffeTrustMapDelegates != null) { + Set result = new HashSet<>(); + for (X509ExtendedTrustManager tm: spiffeTrustMapDelegates.values()) { + result.addAll(Arrays.asList(tm.getAcceptedIssuers())); + } + return result.toArray(new X509Certificate[0]); + } return delegate.getAcceptedIssuers(); } + + private static boolean verifyDnsNameWildcard( + String altNameFromCert, String sanToVerify, boolean ignoreCase) { + String[] splitPattern = splitAtFirstDelimiter(ignoreCase + ? sanToVerify.toLowerCase(Locale.ROOT) : sanToVerify); + String[] splitDnsName = splitAtFirstDelimiter(ignoreCase + ? altNameFromCert.toLowerCase(Locale.ROOT) : altNameFromCert); + if (splitPattern == null || splitDnsName == null) { + return false; + } + if (splitDnsName[0].startsWith("xn--")) { + return false; + } + if (splitPattern[0].contains("*") + && !splitPattern[1].contains("*") + && !splitPattern[0].startsWith("xn--")) { + return splitDnsName[1].equals(splitPattern[1]) + && labelWildcardMatch(splitDnsName[0], splitPattern[0]); + } + return false; + } + + private static boolean labelWildcardMatch(String dnsLabel, String pattern) { + final char glob = '*'; + // Check the special case of a single * pattern, as it's common. 
+ if (pattern.equals("*")) { + return !dnsLabel.isEmpty(); + } + int globIndex = pattern.indexOf(glob); + if (pattern.indexOf(glob, globIndex + 1) == -1) { + return dnsLabel.length() >= pattern.length() - 1 + && dnsLabel.startsWith(pattern.substring(0, globIndex)) + && dnsLabel.endsWith(pattern.substring(globIndex + 1)); + } + return false; + } + + @Nullable + private static String[] splitAtFirstDelimiter(String s) { + int index = s.indexOf('.'); + if (index == -1) { + return null; + } + return new String[]{s.substring(0, index), s.substring(index + 1)}; + } } diff --git a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java index ba03140d627..b37b9bc42e3 100644 --- a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java @@ -36,12 +36,16 @@ import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ClientCall; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Metadata; import io.grpc.Status; @@ -83,7 +87,7 @@ private OrcaOobUtil() {} * class WrrLoadbalancer extends LoadBalancer { * private final Helper originHelper; // the original Helper * - * public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + * public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { * // listener implements the logic for WRR's usage of backend metrics. 
* OrcaReportingHelper orcaHelper = * OrcaOobUtil.newOrcaReportingHelper(originHelper); @@ -236,6 +240,30 @@ protected Helper delegate() { return delegate; } + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + delegate.updateBalancingState(newState, new OrcaOobPicker(newPicker)); + } + + @VisibleForTesting + static final class OrcaOobPicker extends SubchannelPicker { + final SubchannelPicker delegate; + + OrcaOobPicker(SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + PickResult result = delegate.pickSubchannel(args); + Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof SubchannelImpl) { + return result.copyWithSubchannel(((SubchannelImpl) subchannel).delegate()); + } + return result; + } + } + @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { syncContext.throwIfNotInThisSynchronizationContext(); diff --git a/xds/src/main/java/io/grpc/xds/package-info.java b/xds/src/main/java/io/grpc/xds/package-info.java index 74fa88cfe38..9cc15cd5449 100644 --- a/xds/src/main/java/io/grpc/xds/package-info.java +++ b/xds/src/main/java/io/grpc/xds/package-info.java @@ -15,7 +15,7 @@ */ /** - * Library for gPRC proxyless service mesh using Envoy xDS protocol. + * Library for gRPC proxyless service mesh using Envoy xDS protocol. * *

The package currently includes a name resolver plugin and a family of load balancer plugins. * A gRPC channel for a target with {@code "xds:"} scheme will load the plugins and a diff --git a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider index e1c4d4aa427..04a2d9cf7a8 100644 --- a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider +++ b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider @@ -2,7 +2,6 @@ io.grpc.xds.CdsLoadBalancerProvider io.grpc.xds.PriorityLoadBalancerProvider io.grpc.xds.WeightedTargetLoadBalancerProvider io.grpc.xds.ClusterManagerLoadBalancerProvider -io.grpc.xds.ClusterResolverLoadBalancerProvider io.grpc.xds.ClusterImplLoadBalancerProvider io.grpc.xds.LeastRequestLoadBalancerProvider io.grpc.xds.RingHashLoadBalancerProvider diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index 0884587cd95..928520aded7 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -17,59 +17,76 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.XdsLbPolicies.CLUSTER_RESOLVER_POLICY_NAME; -import static org.junit.Assert.fail; +import static io.grpc.xds.XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME; +import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static 
org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.google.common.collect.ImmutableList; +import com.github.xds.type.v3.TypedStruct; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.protobuf.Any; +import com.google.protobuf.Struct; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.Value; +import io.envoyproxy.envoy.config.cluster.v3.CircuitBreakers; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; +import io.envoyproxy.envoy.config.cluster.v3.OutlierDetection; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; +import io.envoyproxy.envoy.config.core.v3.ConfigSource; +import io.envoyproxy.envoy.config.core.v3.RoutingPriority; +import io.envoyproxy.envoy.config.core.v3.SelfConfigSource; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.core.v3.TransportSocket; +import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; +import io.envoyproxy.envoy.extensions.clusters.aggregate.v3.ClusterConfig; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; import io.grpc.Attributes; +import io.grpc.ChannelLogger; import io.grpc.ConnectivityState; -import io.grpc.EquivalentAddressGroup; -import io.grpc.InsecureChannelCredentials; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import 
io.grpc.LoadBalancer.ResolvedAddresses; -import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver; +import io.grpc.NameResolverRegistry; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; -import io.grpc.internal.ObjectPool; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.FakeClock; +import io.grpc.testing.GrpcCleanupRule; import io.grpc.util.GracefulSwitchLoadBalancerAccessor; import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig.DiscoveryMechanism; -import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; -import io.grpc.xds.EnvoyServerProtoData.SuccessRateEjection; -import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; -import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; -import io.grpc.xds.XdsClusterResource.CdsUpdate; -import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsResourceType; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.Executor; -import javax.annotation.Nullable; +import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -88,638 +105,543 @@ @RunWith(JUnit4.class) public class CdsLoadBalancer2Test { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + @Rule + public final 
GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + private static final String SERVER_NAME = "example.com"; private static final String CLUSTER = "cluster-foo.googleapis.com"; private static final String EDS_SERVICE_NAME = "backend-service-1.googleapis.com"; - private static final String DNS_HOST_NAME = "backend-service-dns.googleapis.com:443"; - private static final ServerInfo LRS_SERVER_INFO = - ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); - private final UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); - private final OutlierDetection outlierDetection = OutlierDetection.create( - null, null, null, null, SuccessRateEjection.create(null, null, null, null), null); - - - private static final SynchronizationContext syncContext = new SynchronizationContext( - new Thread.UncaughtExceptionHandler() { - @Override - public void uncaughtException(Thread t, Throwable e) { - throw new RuntimeException(e); - //throw new AssertionError(e); - } - }); + private static final String NODE_ID = "node-id"; + private final io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContext("cert-instance-name", true); + private static final Cluster EDS_CLUSTER = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .build(); + + private final FakeClock fakeClock = new FakeClock(); private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); private final List childBalancers = new ArrayList<>(); - private final FakeXdsClient xdsClient = new FakeXdsClient(); - private final ObjectPool xdsClientPool = new ObjectPool() { - @Override - public XdsClient getObject() { - xdsClientRefs++; - 
return xdsClient; - } - - @Override - public XdsClient returnObject(Object object) { - xdsClientRefs--; - return null; - } - }; + private final XdsTestControlPlaneService controlPlaneService = new XdsTestControlPlaneService(); + private final XdsClient xdsClient = XdsTestUtils.createXdsClient( + Arrays.asList("control-plane.example.com"), + serverInfo -> new GrpcXdsTransportFactory.GrpcXdsTransport( + InProcessChannelBuilder + .forName(serverInfo.target()) + .directExecutor() + .build()), + fakeClock); + private XdsDependencyManager xdsDepManager; @Mock private Helper helper; @Captor private ArgumentCaptor pickerCaptor; - private int xdsClientRefs; - private CdsLoadBalancer2 loadBalancer; + private CdsLoadBalancer2 loadBalancer; + private XdsConfig lastXdsConfig; @Before - public void setUp() { - when(helper.getSynchronizationContext()).thenReturn(syncContext); - lbRegistry.register(new FakeLoadBalancerProvider(CLUSTER_RESOLVER_POLICY_NAME)); + public void setUp() throws Exception { + lbRegistry.register(new FakeLoadBalancerProvider(PRIORITY_POLICY_NAME)); + lbRegistry.register(new FakeLoadBalancerProvider(CLUSTER_IMPL_POLICY_NAME)); lbRegistry.register(new FakeLoadBalancerProvider("round_robin")); + lbRegistry.register(new FakeLoadBalancerProvider("outlier_detection_experimental")); lbRegistry.register( new FakeLoadBalancerProvider("ring_hash_experimental", new RingHashLoadBalancerProvider())); lbRegistry.register(new FakeLoadBalancerProvider("least_request_experimental", new LeastRequestLoadBalancerProvider())); - loadBalancer = new CdsLoadBalancer2(helper, lbRegistry); - loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(Collections.emptyList()) - .setAttributes( - // Other attributes not used by cluster_resolver LB are omitted. 
- Attributes.newBuilder() - .set(InternalXdsAttributes.XDS_CLIENT_POOL, xdsClientPool) - .build()) - .setLoadBalancingPolicyConfig(new CdsConfig(CLUSTER)) - .build()); - assertThat(Iterables.getOnlyElement(xdsClient.watchers.keySet())).isEqualTo(CLUSTER); + lbRegistry.register(new FakeLoadBalancerProvider("wrr_locality_experimental", + new WrrLocalityLoadBalancerProvider())); + CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); + lbRegistry.register(cdsLoadBalancerProvider); + loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); + + cleanupRule.register(InProcessServerBuilder + .forName("control-plane.example.com") + .addService(controlPlaneService) + .directExecutor() + .build() + .start()); + + SynchronizationContext syncContext = new SynchronizationContext((t, e) -> { + throw new AssertionError(e); + }); + when(helper.getSynchronizationContext()).thenReturn(syncContext); + when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService()); + + NameResolver.Args nameResolverArgs = NameResolver.Args.newBuilder() + .setDefaultPort(8080) + .setProxyDetector((address) -> null) + .setSynchronizationContext(syncContext) + .setServiceConfigParser(mock(NameResolver.ServiceConfigParser.class)) + .setChannelLogger(mock(ChannelLogger.class)) + .setScheduledExecutorService(fakeClock.getScheduledExecutorService()) + .setNameResolverRegistry(new NameResolverRegistry()) + .build(); + + xdsDepManager = new XdsDependencyManager( + xdsClient, + syncContext, + SERVER_NAME, + SERVER_NAME, + nameResolverArgs); + + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, ImmutableMap.of( + SERVER_NAME, ControlPlaneRule.buildClientListener(SERVER_NAME, "my-route"))); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_RDS, ImmutableMap.of( + "my-route", XdsTestUtils.buildRouteConfiguration(SERVER_NAME, "my-route", CLUSTER))); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + 
EDS_SERVICE_NAME, ControlPlaneRule.buildClusterLoadAssignment( + "127.0.0.1", "", 1234, EDS_SERVICE_NAME))); } @After public void tearDown() { - loadBalancer.shutdown(); - assertThat(xdsClient.watchers).isEmpty(); - assertThat(xdsClientRefs).isEqualTo(0); + if (loadBalancer != null) { + shutdownLoadBalancer(); + } assertThat(childBalancers).isEmpty(); + + if (xdsDepManager != null) { + xdsDepManager.shutdown(); + } + xdsClient.shutdown(); } - @Test - public void discoverTopLevelEdsCluster() { - CdsUpdate update = - CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, - outlierDetection) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(childBalancers).hasSize(1); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.name).isEqualTo(CLUSTER_RESOLVER_POLICY_NAME); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 100L, upstreamTlsContext, outlierDetection); - assertThat( - GracefulSwitchLoadBalancerAccessor.getChildProvider(childLbConfig.lbConfig).getPolicyName()) - .isEqualTo("round_robin"); + private void shutdownLoadBalancer() { + LoadBalancer lb = this.loadBalancer; + this.loadBalancer = null; // Must avoid calling acceptResolvedAddresses after shutdown + lb.shutdown(); } @Test - public void discoverTopLevelLogicalDnsCluster() { - CdsUpdate update = - CdsUpdate.forLogicalDns(CLUSTER, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext) - .leastRequestLbPolicy(3).build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); + public void discoverTopLevelCluster() { + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER) + 
.setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN) + .setLrsServer(ConfigSource.newBuilder() + .setSelf(SelfConfigSource.getDefaultInstance())) + .setCircuitBreakers(CircuitBreakers.newBuilder() + .addThresholds(CircuitBreakers.Thresholds.newBuilder() + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .setTransportSocket(TransportSocket.newBuilder() + .setName("envoy.transport_sockets.tls") + .setTypedConfig(Any.pack(UpstreamTlsContext.newBuilder() + .setCommonTlsContext(upstreamTlsContext.getCommonTlsContext()) + .build()))) + .setOutlierDetection(OutlierDetection.getDefaultInstance()) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.name).isEqualTo(CLUSTER_RESOLVER_POLICY_NAME); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.LOGICAL_DNS, null, - DNS_HOST_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, null); - assertThat( - GracefulSwitchLoadBalancerAccessor.getChildProvider(childLbConfig.lbConfig).getPolicyName()) - .isEqualTo("least_request_experimental"); - LeastRequestConfig lrConfig = (LeastRequestConfig) - GracefulSwitchLoadBalancerAccessor.getChildConfig(childLbConfig.lbConfig); - 
assertThat(lrConfig.choiceCount).isEqualTo(3); + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); } @Test public void nonAggregateCluster_resourceNotExist_returnErrorPicker() { - xdsClient.deliverResourceNotExist(CLUSTER); + startXdsDepManager(); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); - assertPicker(pickerCaptor.getValue(), unavailable, null); + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + ": NOT_FOUND. " + + "Details: Timed out waiting for resource " + CLUSTER + + " from xDS server nodeID: " + NODE_ID; + Status unavailable = Status.UNAVAILABLE.withDescription(expectedDescription); + assertPickerStatus(pickerCaptor.getValue(), unavailable); assertThat(childBalancers).isEmpty(); } @Test public void nonAggregateCluster_resourceUpdate() { - CdsUpdate update = - CdsUpdate.forEds(CLUSTER, null, null, 100L, upstreamTlsContext, outlierDetection) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); + lbRegistry.register(new PriorityLoadBalancerProvider()); + Cluster cluster = EDS_CLUSTER.toBuilder() + .setCircuitBreakers(CircuitBreakers.newBuilder() + .addThresholds(CircuitBreakers.Thresholds.newBuilder() + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - DiscoveryMechanism instance = 
Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, null, null, null, - 100L, upstreamTlsContext, outlierDetection); - - update = CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, null, - outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - childLbConfig = (ClusterResolverConfig) childBalancer.config; - instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 200L, null, outlierDetection); + ClusterImplConfig childLbConfig = (ClusterImplConfig) childBalancer.config; + assertThat(childLbConfig.cluster).isEqualTo(CLUSTER); + assertThat(childLbConfig.maxConcurrentRequests).isEqualTo(100L); + + cluster = EDS_CLUSTER.toBuilder() + .setCircuitBreakers(CircuitBreakers.newBuilder() + .addThresholds(CircuitBreakers.Thresholds.newBuilder() + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(200)))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); + childBalancer = Iterables.getOnlyElement(childBalancers); + childLbConfig = (ClusterImplConfig) childBalancer.config; + assertThat(childLbConfig.maxConcurrentRequests).isEqualTo(200L); } @Test public void nonAggregateCluster_resourceRevoked() { - CdsUpdate update = - CdsUpdate.forLogicalDns(CLUSTER, DNS_HOST_NAME, null, 100L, upstreamTlsContext) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); + lbRegistry.register(new PriorityLoadBalancerProvider()); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, EDS_CLUSTER)); + startXdsDepManager(); + + verify(helper, 
never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.LOGICAL_DNS, null, - DNS_HOST_NAME, null, 100L, upstreamTlsContext, null); + ClusterImplConfig childLbConfig = (ClusterImplConfig) childBalancer.config; + assertThat(childLbConfig.cluster).isEqualTo(CLUSTER); + + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); - xdsClient.deliverResourceNotExist(CLUSTER); assertThat(childBalancer.shutdown).isTrue(); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + ": NOT_FOUND. " + + "Details: Resource " + CLUSTER + " does not exist nodeID: " + NODE_ID; + Status unavailable = Status.UNAVAILABLE.withDescription(expectedDescription); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), unavailable, null); + assertPickerStatus(pickerCaptor.getValue(), unavailable); assertThat(childBalancer.shutdown).isTrue(); assertThat(childBalancers).isEmpty(); } @Test - public void discoverAggregateCluster() { - String cluster1 = "cluster-01.googleapis.com"; - String cluster2 = "cluster-02.googleapis.com"; - // CLUSTER (aggr.) 
-> [cluster1 (aggr.), cluster2 (logical DNS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Arrays.asList(cluster1, cluster2)) - .ringHashLbPolicy(100L, 1000L).build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - assertThat(childBalancers).isEmpty(); - String cluster3 = "cluster-03.googleapis.com"; - String cluster4 = "cluster-04.googleapis.com"; - // cluster1 (aggr.) -> [cluster3 (EDS), cluster4 (EDS)] - CdsUpdate update1 = - CdsUpdate.forAggregate(cluster1, Arrays.asList(cluster3, cluster4)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - assertThat(xdsClient.watchers.keySet()).containsExactly( - CLUSTER, cluster1, cluster2, cluster3, cluster4); - assertThat(childBalancers).isEmpty(); - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); - assertThat(childBalancers).isEmpty(); - CdsUpdate update2 = - CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, null, 100L, null) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - assertThat(childBalancers).isEmpty(); - CdsUpdate update4 = - CdsUpdate.forEds(cluster4, null, LRS_SERVER_INFO, 300L, null, outlierDetection) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster4, update4); - assertThat(childBalancers).hasSize(1); // all non-aggregate clusters discovered + public void dynamicCluster() { + String clusterName = "cluster2"; + Cluster cluster = EDS_CLUSTER.toBuilder() + .setName(clusterName) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + clusterName, cluster, + CLUSTER, Cluster.newBuilder().setName(CLUSTER).build())); + startXdsDepManager(new CdsConfig(clusterName, /*dynamic=*/ true)); + + verify(helper, 
never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.name).isEqualTo(CLUSTER_RESOLVER_POLICY_NAME); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(3); - // Clusters on higher level has higher priority: [cluster2, cluster3, cluster4] - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, null, 100L, null, null); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster3, - DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(2), cluster4, - DiscoveryMechanism.Type.EDS, null, null, LRS_SERVER_INFO, 300L, null, outlierDetection); - assertThat( - GracefulSwitchLoadBalancerAccessor.getChildProvider(childLbConfig.lbConfig).getPolicyName()) - .isEqualTo("ring_hash_experimental"); // dominated by top-level cluster's config - RingHashConfig ringHashConfig = (RingHashConfig) - GracefulSwitchLoadBalancerAccessor.getChildConfig(childLbConfig.lbConfig); - assertThat(ringHashConfig.minRingSize).isEqualTo(100L); - assertThat(ringHashConfig.maxRingSize).isEqualTo(1000L); - } + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); - @Test - public void aggregateCluster_noNonAggregateClusterExits_returnErrorPicker() { - String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) 
-> [cluster1 (EDS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - xdsClient.deliverResourceNotExist(cluster1); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); - assertPicker(pickerCaptor.getValue(), unavailable, null); - assertThat(childBalancers).isEmpty(); + assertThat(this.lastXdsConfig.getClusters()).containsKey(clusterName); + shutdownLoadBalancer(); + assertThat(this.lastXdsConfig.getClusters()).doesNotContainKey(clusterName); } @Test - public void aggregateCluster_descendantClustersRevoked() { - String cluster1 = "cluster-01.googleapis.com"; - String cluster2 = "cluster-02.googleapis.com"; - // CLUSTER (aggr.) 
-> [cluster1 (EDS), cluster2 (logical DNS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Arrays.asList(cluster1, cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - CdsUpdate update1 = CdsUpdate.forEds(cluster1, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - CdsUpdate update2 = - CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(2); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster1, - DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, - null); - - // Revoke cluster1, should still be able to proceed with cluster2. - xdsClient.deliverResourceNotExist(cluster1); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - assertDiscoveryMechanism(Iterables.getOnlyElement(childLbConfig.discoveryMechanisms), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, - null); - verify(helper, never()).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), any(SubchannelPicker.class)); - - // All revoked. 
- xdsClient.deliverResourceNotExist(cluster2); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); - assertPicker(pickerCaptor.getValue(), unavailable, null); - assertThat(childBalancer.shutdown).isTrue(); - assertThat(childBalancers).isEmpty(); - } + public void discoverAggregateCluster_createsPriorityLbPolicy() { + CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); + lbRegistry.register(cdsLoadBalancerProvider); + loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); - @Test - public void aggregateCluster_rootClusterRevoked() { String cluster1 = "cluster-01.googleapis.com"; String cluster2 = "cluster-02.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1 (EDS), cluster2 (logical DNS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Arrays.asList(cluster1, cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - CdsUpdate update1 = CdsUpdate.forEds(cluster1, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - CdsUpdate update2 = - CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(2); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster1, - DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_INFO, 200L, - 
upstreamTlsContext, outlierDetection); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, - null); - - xdsClient.deliverResourceNotExist(CLUSTER); - assertThat(xdsClient.watchers.keySet()) - .containsExactly(CLUSTER); // subscription to all descendant clusters cancelled - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); - assertPicker(pickerCaptor.getValue(), unavailable, null); - assertThat(childBalancer.shutdown).isTrue(); - assertThat(childBalancers).isEmpty(); - } - - @Test - public void aggregateCluster_intermediateClusterChanges() { - String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - - // CLUSTER (aggr.) -> [cluster2 (aggr.)] - String cluster2 = "cluster-02.googleapis.com"; - update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster2); - - // cluster2 (aggr.) 
-> [cluster3 (EDS)] String cluster3 = "cluster-03.googleapis.com"; - CdsUpdate update2 = - CdsUpdate.forAggregate(cluster2, Collections.singletonList(cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster2, cluster3); - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); + String cluster4 = "cluster-04.googleapis.com"; + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + // CLUSTER (aggr.) -> [cluster1 (aggr.), cluster2 (logical DNS), cluster3 (EDS)] + CLUSTER, Cluster.newBuilder() + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster1) + .addClusters(cluster2) + .addClusters(cluster3) + .build()))) + .setLbPolicy(Cluster.LbPolicy.RING_HASH) + .build(), + // cluster1 (aggr.) 
-> [cluster3 (EDS), cluster4 (EDS)] + cluster1, Cluster.newBuilder() + .setName(cluster1) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster3) + .addClusters(cluster4) + .build()))) + .build(), + cluster2, Cluster.newBuilder() + .setName(cluster2) + .setType(Cluster.DiscoveryType.LOGICAL_DNS) + .setLoadAssignment(ClusterLoadAssignment.newBuilder() + .addEndpoints(LocalityLbEndpoints.newBuilder() + .addLbEndpoints(LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("dns.example.com") + .setPortValue(1111))))))) + .build(), + cluster3, EDS_CLUSTER.toBuilder() + .setName(cluster3) + .setCircuitBreakers(CircuitBreakers.newBuilder() + .addThresholds(CircuitBreakers.Thresholds.newBuilder() + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .build(), + cluster4, EDS_CLUSTER.toBuilder().setName(cluster4).build())); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, cluster3, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 100L, upstreamTlsContext, outlierDetection); - - // cluster2 revoked - xdsClient.deliverResourceNotExist(cluster2); - assertThat(xdsClient.watchers.keySet()) - .containsExactly(CLUSTER, cluster2); // cancelled subscription to cluster3 - verify(helper).updateBalancingState( - 
eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); - assertPicker(pickerCaptor.getValue(), unavailable, null); - assertThat(childBalancer.shutdown).isTrue(); - assertThat(childBalancers).isEmpty(); + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); + PriorityLoadBalancerProvider.PriorityLbConfig childLbConfig = + (PriorityLoadBalancerProvider.PriorityLbConfig) childBalancer.config; + assertThat(childLbConfig.priorities).hasSize(3); + assertThat(childLbConfig.priorities.get(0)).isEqualTo(cluster3); + assertThat(childLbConfig.priorities.get(1)).isEqualTo(cluster4); + assertThat(childLbConfig.priorities.get(2)).isEqualTo(cluster2); + assertThat(childLbConfig.childConfigs).hasSize(3); + PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig childConfig3 = + childLbConfig.childConfigs.get(cluster3); + assertThat( + GracefulSwitchLoadBalancerAccessor.getChildProvider(childConfig3.childConfig) + .getPolicyName()) + .isEqualTo("cds_experimental"); + PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig childConfig4 = + childLbConfig.childConfigs.get(cluster4); + assertThat( + GracefulSwitchLoadBalancerAccessor.getChildProvider(childConfig4.childConfig) + .getPolicyName()) + .isEqualTo("cds_experimental"); + PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig childConfig2 = + childLbConfig.childConfigs.get(cluster2); + assertThat( + GracefulSwitchLoadBalancerAccessor.getChildProvider(childConfig2.childConfig) + .getPolicyName()) + .isEqualTo("cds_experimental"); } @Test - public void aggregateCluster_withLoops() { - String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) 
-> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - - // CLUSTER (aggr.) -> [cluster2 (aggr.)] - String cluster2 = "cluster-02.googleapis.com"; - update = - CdsUpdate.forAggregate(cluster1, Collections.singletonList(cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - - // cluster2 (aggr.) -> [cluster3 (EDS), cluster1 (parent), cluster2 (self), cluster3 (dup)] - String cluster3 = "cluster-03.googleapis.com"; - CdsUpdate update2 = - CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3, cluster1, cluster2, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2, cluster3); - - reset(helper); - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: circular aggregate clusters directly under cluster-02.googleapis.com for root" - + " cluster cluster-foo.googleapis.com, named [cluster-01.googleapis.com," - + " cluster-02.googleapis.com]"); - assertPicker(pickerCaptor.getValue(), unavailable, null); - } + // Both priorities will get tried using real priority LB policy. 
+ public void discoverAggregateCluster_testChildCdsLbPolicyParsing() { + lbRegistry.register(new PriorityLoadBalancerProvider()); + CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); + lbRegistry.register(cdsLoadBalancerProvider); + loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); - @Test - public void aggregateCluster_withLoops_afterEds() { String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - - // CLUSTER (aggr.) -> [cluster2 (aggr.)] String cluster2 = "cluster-02.googleapis.com"; - update = - CdsUpdate.forAggregate(cluster1, Collections.singletonList(cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - - String cluster3 = "cluster-03.googleapis.com"; - CdsUpdate update2 = - CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); - - // cluster2 (aggr.) 
-> [cluster3 (EDS)] - CdsUpdate update2a = - CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3, cluster1, cluster2, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2a); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2, cluster3); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: circular aggregate clusters directly under cluster-02.googleapis.com for root" - + " cluster cluster-foo.googleapis.com, named [cluster-01.googleapis.com," - + " cluster-02.googleapis.com]"); - assertPicker(pickerCaptor.getValue(), unavailable, null); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + // CLUSTER (aggr.) -> [cluster1 (EDS), cluster2 (EDS)] + CLUSTER, Cluster.newBuilder() + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster1) + .addClusters(cluster2) + .build()))) + .build(), + cluster1, EDS_CLUSTER.toBuilder().setName(cluster1).build(), + cluster2, EDS_CLUSTER.toBuilder().setName(cluster2).build())); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(2); + ClusterImplConfig cluster1ImplConfig = + (ClusterImplConfig) childBalancers.get(0).config; + assertThat(cluster1ImplConfig.cluster) + .isEqualTo("cluster-01.googleapis.com"); + assertThat(cluster1ImplConfig.edsServiceName) + .isEqualTo("backend-service-1.googleapis.com"); + ClusterImplConfig cluster2ImplConfig = + (ClusterImplConfig) childBalancers.get(1).config; + assertThat(cluster2ImplConfig.cluster) + .isEqualTo("cluster-02.googleapis.com"); + assertThat(cluster2ImplConfig.edsServiceName) + .isEqualTo("backend-service-1.googleapis.com"); } @Test 
- public void aggregateCluster_duplicateChildren() { - String cluster1 = "cluster-01.googleapis.com"; - String cluster2 = "cluster-02.googleapis.com"; - String cluster3 = "cluster-03.googleapis.com"; - String cluster4 = "cluster-04.googleapis.com"; - - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - - // cluster1 (aggr) -> [cluster3 (EDS), cluster2 (aggr), cluster4 (aggr)] - CdsUpdate update1 = - CdsUpdate.forAggregate(cluster1, Arrays.asList(cluster3, cluster2, cluster4, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - assertThat(xdsClient.watchers.keySet()).containsExactly( - cluster3, cluster4, cluster2, cluster1, CLUSTER); - xdsClient.watchers.values().forEach(list -> assertThat(list.size()).isEqualTo(1)); - - // cluster2 (agg) -> [cluster3 (EDS), cluster4 {agg}] with dups - CdsUpdate update2 = - CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3, cluster4, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - - // Define EDS cluster - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); - - // cluster4 (agg) -> [cluster3 (EDS)] with dups (3 copies) - CdsUpdate update4 = - CdsUpdate.forAggregate(cluster4, Arrays.asList(cluster3, cluster3, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster4, update4); - xdsClient.watchers.values().forEach(list -> assertThat(list.size()).isEqualTo(1)); - - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - 
assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, cluster3, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 100L, upstreamTlsContext, outlierDetection); + public void aggregateCluster_noChildren() { + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + // CLUSTER (aggr.) -> [] + CLUSTER, Cluster.newBuilder() + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .build()))) + .build())); + startXdsDepManager(); + + verify(helper) + .updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + Status actualStatus = result.getStatus(); + assertThat(actualStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(actualStatus.getDescription()) + .contains("aggregate ClusterConfig.clusters must not be empty"); + assertThat(childBalancers).isEmpty(); } @Test - public void aggregateCluster_discoveryErrorBeforeChildLbCreated_returnErrorPicker() { + public void aggregateCluster_noNonAggregateClusterExits_returnErrorPicker() { + lbRegistry.register(new PriorityLoadBalancerProvider()); + CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); + lbRegistry.register(cdsLoadBalancerProvider); + loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); + String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) 
-> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - Status error = Status.RESOURCE_EXHAUSTED.withDescription("OOM"); - xdsClient.deliverError(error); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + // CLUSTER (aggr.) -> [cluster1 (missing)] + CLUSTER, Cluster.newBuilder() + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster1) + .build()))) + .setLbPolicy(Cluster.LbPolicy.RING_HASH) + .build())); + startXdsDepManager(); + verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status expectedError = Status.UNAVAILABLE.withDescription( - "Unable to load CDS cluster-foo.googleapis.com. xDS server returned: " - + "RESOURCE_EXHAUSTED: OOM"); - assertPicker(pickerCaptor.getValue(), expectedError, null); + String expectedDescription = "Error retrieving CDS resource " + cluster1 + ": NOT_FOUND. " + + "Details: Timed out waiting for resource " + cluster1 + " from xDS server nodeID: " + + NODE_ID; + Status status = Status.UNAVAILABLE.withDescription(expectedDescription); + assertPickerStatus(pickerCaptor.getValue(), status); assertThat(childBalancers).isEmpty(); } @Test - public void aggregateCluster_discoveryErrorAfterChildLbCreated_propagateToChildLb() { - String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) 
-> [cluster1 (logical DNS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - CdsUpdate update1 = - CdsUpdate.forLogicalDns(cluster1, DNS_HOST_NAME, LRS_SERVER_INFO, 200L, null) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - FakeLoadBalancer childLb = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childLb.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - - Status error = Status.RESOURCE_EXHAUSTED.withDescription("OOM"); - xdsClient.deliverError(error); - assertThat(childLb.upstreamError.getCode()).isEqualTo(Status.Code.UNAVAILABLE); - assertThat(childLb.upstreamError.getDescription()).contains("RESOURCE_EXHAUSTED: OOM"); - assertThat(childLb.shutdown).isFalse(); // child LB may choose to keep working - } - - @Test - public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_returnErrorPicker() { - Status upstreamError = Status.UNAVAILABLE.withDescription("unreachable"); - loadBalancer.handleNameResolutionError(upstreamError); + public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_failingPicker() { + Status status = Status.UNAVAILABLE.withDescription("unreachable"); + loadBalancer.handleNameResolutionError(status); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), upstreamError, null); + assertPickerStatus(pickerCaptor.getValue(), status); } @Test public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThrough() { - CdsUpdate update = CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER) + 
.setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.shutdown).isFalse(); + loadBalancer.handleNameResolutionError(Status.UNAVAILABLE.withDescription("unreachable")); assertThat(childBalancer.upstreamError.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(childBalancer.upstreamError.getDescription()).isEqualTo("unreachable"); - verify(helper, never()).updateBalancingState( - any(ConnectivityState.class), any(SubchannelPicker.class)); + verify(helper).updateBalancingState( + eq(ConnectivityState.CONNECTING), any(SubchannelPicker.class)); } @Test public void unknownLbProvider() { - try { - xdsClient.deliverCdsUpdate(CLUSTER, - CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, - outlierDetection) - .lbPolicyConfig(ImmutableMap.of("unknownLb", ImmutableMap.of("foo", "bar"))).build()); - } catch (Exception e) { - assertThat(e).hasMessageThat().contains("unknownLb"); - return; - } - fail("Expected the unknown LB to cause an exception"); + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() + .addPolicies(Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setTypedConfig(Any.pack(TypedStruct.newBuilder() + 
.setTypeUrl("type.googleapis.com/unknownLb") + .setValue(Struct.getDefaultInstance()) + .build()))))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + Status actualStatus = result.getStatus(); + assertThat(actualStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(actualStatus.getDescription()).contains("Invalid LoadBalancingPolicy"); } @Test public void invalidLbConfig() { - try { - xdsClient.deliverCdsUpdate(CLUSTER, - CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, - outlierDetection).lbPolicyConfig( - ImmutableMap.of("ring_hash_experimental", ImmutableMap.of("minRingSize", "-1"))) + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() + .addPolicies(Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setTypedConfig(Any.pack(TypedStruct.newBuilder() + .setTypeUrl("type.googleapis.com/ring_hash_experimental") + .setValue(Struct.newBuilder() + .putFields("minRingSize", Value.newBuilder().setNumberValue(-1).build())) + .build()))))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + Status actualStatus = result.getStatus(); + 
assertThat(actualStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(actualStatus.getDescription()).contains("Invalid 'minRingSize'"); + } + + private void startXdsDepManager() { + startXdsDepManager(new CdsConfig(CLUSTER)); + } + + private void startXdsDepManager(final CdsConfig cdsConfig) { + xdsDepManager.start( + xdsConfig -> { + if (!xdsConfig.hasValue()) { + throw new AssertionError("" + xdsConfig.getStatus()); + } + this.lastXdsConfig = xdsConfig.getValue(); + if (loadBalancer == null) { + return; + } + loadBalancer.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(Collections.emptyList()) + .setAttributes(Attributes.newBuilder() + .set(XdsAttributes.XDS_CONFIG, xdsConfig.getValue()) + .set(XdsAttributes.XDS_CLUSTER_SUBSCRIPT_REGISTRY, xdsDepManager) + .build()) + .setLoadBalancingPolicyConfig(cdsConfig) .build()); - } catch (Exception e) { - assertThat(e).hasMessageThat().contains("Unable to parse"); - return; - } - fail("Expected the invalid config to cause an exception"); + }); + // trigger does not exist timer, so broken config is more obvious + fakeClock.forwardTime(10, TimeUnit.MINUTES); } - private static void assertPicker(SubchannelPicker picker, Status expectedStatus, - @Nullable Subchannel expectedSubchannel) { + private static void assertPickerStatus(SubchannelPicker picker, Status expectedStatus) { PickResult result = picker.pickSubchannel(mock(PickSubchannelArgs.class)); Status actualStatus = result.getStatus(); assertThat(actualStatus.getCode()).isEqualTo(expectedStatus.getCode()); assertThat(actualStatus.getDescription()).isEqualTo(expectedStatus.getDescription()); - if (actualStatus.isOk()) { - assertThat(result.getSubchannel()).isSameInstanceAs(expectedSubchannel); - } - } - - private static void assertDiscoveryMechanism(DiscoveryMechanism instance, String name, - DiscoveryMechanism.Type type, @Nullable String edsServiceName, @Nullable String dnsHostName, - @Nullable ServerInfo lrsServerInfo, @Nullable 
Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, @Nullable OutlierDetection outlierDetection) { - assertThat(instance.cluster).isEqualTo(name); - assertThat(instance.type).isEqualTo(type); - assertThat(instance.edsServiceName).isEqualTo(edsServiceName); - assertThat(instance.dnsHostName).isEqualTo(dnsHostName); - assertThat(instance.lrsServerInfo).isEqualTo(lrsServerInfo); - assertThat(instance.maxConcurrentRequests).isEqualTo(maxConcurrentRequests); - assertThat(instance.tlsContext).isEqualTo(tlsContext); - assertThat(instance.outlierDetection).isEqualTo(outlierDetection); } private final class FakeLoadBalancerProvider extends LoadBalancerProvider { @@ -778,8 +700,9 @@ private final class FakeLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { config = resolvedAddresses.getLoadBalancingPolicyConfig(); + return Status.OK; } @Override @@ -793,53 +716,4 @@ public void shutdown() { childBalancers.remove(this); } } - - private static final class FakeXdsClient extends XdsClient { - // watchers needs to support any non-cyclic shaped graphs - private final Map>> watchers = new HashMap<>(); - - @Override - @SuppressWarnings("unchecked") - public void watchXdsResource(XdsResourceType type, - String resourceName, - ResourceWatcher watcher, Executor syncContext) { - assertThat(type.typeName()).isEqualTo("CDS"); - watchers.computeIfAbsent(resourceName, k -> new ArrayList<>()) - .add((ResourceWatcher)watcher); - } - - @Override - public void cancelXdsResourceWatch(XdsResourceType type, - String resourceName, - ResourceWatcher watcher) { - assertThat(type.typeName()).isEqualTo("CDS"); - assertThat(watchers).containsKey(resourceName); - List> watcherList = watchers.get(resourceName); - assertThat(watcherList.remove(watcher)).isTrue(); - if (watcherList.isEmpty()) { - watchers.remove(resourceName); - } - } - - 
private void deliverCdsUpdate(String clusterName, CdsUpdate update) { - if (watchers.containsKey(clusterName)) { - List> resourceWatchers = - ImmutableList.copyOf(watchers.get(clusterName)); - resourceWatchers.forEach(w -> w.onChanged(update)); - } - } - - private void deliverResourceNotExist(String clusterName) { - if (watchers.containsKey(clusterName)) { - ImmutableList.copyOf(watchers.get(clusterName)) - .forEach(w -> w.onResourceDoesNotExist(clusterName)); - } - } - - private void deliverError(Status error) { - watchers.values().stream() - .flatMap(List::stream) - .forEach(w -> w.onError(error)); - } - } } diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index aaaed9554f4..9277675385a 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -18,6 +18,9 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static io.grpc.xds.ClusterImplLoadBalancer.ATTR_SUBCHANNEL_ADDRESS_NAME; +import static io.grpc.xds.XdsNameResolver.AUTO_HOST_REWRITE_KEY; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -35,6 +38,7 @@ import io.grpc.InsecureChannelCredentials; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.FixedResultPicker; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickDetailsConsumer; import io.grpc.LoadBalancer.PickResult; @@ -47,11 +51,12 @@ import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; import io.grpc.Metadata; +import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; 
-import io.grpc.internal.ObjectPool; +import io.grpc.internal.PickFirstLoadBalancerProvider; import io.grpc.internal.PickSubchannelArgsImpl; import io.grpc.protobuf.ProtoUtils; import io.grpc.testing.TestMethodDescriptors; @@ -63,6 +68,7 @@ import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; +import io.grpc.xds.client.BackendMetricPropagation; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.LoadReportClient; import io.grpc.xds.client.LoadStatsManager2; @@ -72,7 +78,9 @@ import io.grpc.xds.client.Stats.ClusterStats; import io.grpc.xds.client.Stats.UpstreamLocalityStats; import io.grpc.xds.client.XdsClient; +import io.grpc.xds.internal.XdsInternalAttributes; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.SecurityProtocolNegotiators; import io.grpc.xds.internal.security.SslContextProvider; import io.grpc.xds.internal.security.SslContextProviderSupplier; import java.net.SocketAddress; @@ -132,19 +140,6 @@ public void uncaughtException(Thread t, Throwable e) { private final LoadStatsManager2 loadStatsManager = new LoadStatsManager2(fakeClock.getStopwatchSupplier()); private final FakeXdsClient xdsClient = new FakeXdsClient(); - private final ObjectPool xdsClientPool = new ObjectPool() { - @Override - public XdsClient getObject() { - xdsClientRefs++; - return xdsClient; - } - - @Override - public XdsClient returnObject(Object object) { - xdsClientRefs--; - return null; - } - }; private final CallCounterProvider callCounterProvider = new CallCounterProvider() { @Override public AtomicLong getOrCreate(String cluster, @Nullable String edsServiceName) { @@ -185,14 +180,15 @@ public void handleResolvedAddresses_propagateToChildPolicy() { null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( 
weightedTargetProvider, weightedTargetConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(downstreamBalancers); assertThat(Iterables.getOnlyElement(childBalancer.addresses)).isEqualTo(endpoint); assertThat(childBalancer.config).isSameInstanceAs(weightedTargetConfig); - assertThat(childBalancer.attributes.get(InternalXdsAttributes.XDS_CLIENT_POOL)) - .isSameInstanceAs(xdsClientPool); + assertThat(childBalancer.attributes.get(io.grpc.xds.XdsAttributes.XDS_CLIENT)) + .isSameInstanceAs(xdsClient); + assertThat(childBalancer.attributes.get(NameResolver.ATTR_BACKEND_SERVICE)).isEqualTo(CLUSTER); } /** @@ -212,7 +208,7 @@ public void handleResolvedAddresses_childPolicyChanges() { null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), configWithWeightedTarget); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(downstreamBalancers); @@ -227,7 +223,7 @@ public void handleResolvedAddresses_childPolicyChanges() { null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( wrrLocalityProvider, wrrLocalityConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); deliverAddressesAndConfig(Collections.singletonList(endpoint), configWithWrrLocality); childBalancer = Iterables.getOnlyElement(downstreamBalancers); assertThat(childBalancer.name).isEqualTo(XdsLbPolicies.WRR_LOCALITY_POLICY_NAME); @@ -253,7 +249,7 @@ public void nameResolutionError_afterChildPolicyInstantiated_propagateToDownstre null, 
Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(downstreamBalancers); @@ -266,7 +262,7 @@ public void nameResolutionError_afterChildPolicyInstantiated_propagateToDownstre } @Test - public void pick_addsLocalityLabel() { + public void pick_addsOptionalLabels() { LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); WeightedTargetConfig weightedTargetConfig = buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); @@ -274,12 +270,13 @@ public void pick_addsLocalityLabel() { null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); leafBalancer.createSubChannel(); FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); fakeSubchannel.setConnectedEagIndex(0); fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); assertThat(currentState).isEqualTo(ConnectivityState.READY); @@ -292,6 +289,31 @@ public void pick_addsLocalityLabel() { // The value will be determined by the parent policy, so can be different than the value used in // makeAddress() for the test. 
verify(detailsConsumer).addOptionalLabel("grpc.lb.locality", locality.toString()); + verify(detailsConsumer).addOptionalLabel("grpc.lb.backend_service", CLUSTER); + } + + @Test + public void pick_noResult_addsClusterLabel() { + LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); + WeightedTargetConfig weightedTargetConfig = + buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); + EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); + deliverAddressesAndConfig(Collections.singletonList(endpoint), config); + FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); + leafBalancer.deliverSubchannelState(PickResult.withNoResult(), ConnectivityState.CONNECTING); + assertThat(currentState).isEqualTo(ConnectivityState.CONNECTING); + + PickDetailsConsumer detailsConsumer = mock(PickDetailsConsumer.class); + pickSubchannelArgs = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, detailsConsumer); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getStatus().isOk()).isTrue(); + verify(detailsConsumer).addOptionalLabel("grpc.lb.backend_service", CLUSTER); } @Test @@ -303,12 +325,13 @@ public void recordLoadStats() { null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer leafBalancer = 
Iterables.getOnlyElement(downstreamBalancers); Subchannel subchannel = leafBalancer.createSubChannel(); FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); fakeSubchannel.setConnectedEagIndex(0); fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); assertThat(currentState).isEqualTo(ConnectivityState.READY); @@ -381,9 +404,117 @@ public void recordLoadStats() { assertThat(clusterStats.upstreamLocalityStatsList()).isEmpty(); // no longer reported } - // TODO(dnvindhya): This test has been added as a fix to verify - // https://github.com/grpc/grpc-java/issues/11434. - // Once we update PickFirstLeafLoadBalancer as default LoadBalancer, update the test. + @Test + public void recordLoadStats_orcaLrsPropagationEnabled() { + boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; + LoadStatsManager2.isEnabledOrcaLrsPropagation = true; + BackendMetricPropagation backendMetricPropagation = BackendMetricPropagation.fromMetricSpecs( + Arrays.asList("application_utilization", "cpu_utilization", "named_metrics.named1")); + LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); + WeightedTargetConfig weightedTargetConfig = + buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), backendMetricPropagation); + EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); + deliverAddressesAndConfig(Collections.singletonList(endpoint), config); + FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); + Subchannel subchannel = leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = 
helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getStatus().isOk()).isTrue(); + ClientStreamTracer streamTracer = result.getStreamTracerFactory().newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); + Metadata trailersWithOrcaLoadReport = new Metadata(); + trailersWithOrcaLoadReport.put(ORCA_ENDPOINT_LOAD_METRICS_KEY, + OrcaLoadReport.newBuilder() + .setApplicationUtilization(1.414) + .setCpuUtilization(0.5) + .setMemUtilization(0.034) + .putNamedMetrics("named1", 3.14159) + .putNamedMetrics("named2", -1.618).build()); + streamTracer.inboundTrailers(trailersWithOrcaLoadReport); + streamTracer.streamClosed(Status.OK); + ClusterStats clusterStats = + Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); + UpstreamLocalityStats localityStats = + Iterables.getOnlyElement(clusterStats.upstreamLocalityStatsList()); + + assertThat(localityStats.loadMetricStatsMap()).containsKey("application_utilization"); + assertThat(localityStats.loadMetricStatsMap().get("application_utilization").totalMetricValue()) + .isWithin(TOLERANCE).of(1.414); + assertThat(localityStats.loadMetricStatsMap()).containsKey("cpu_utilization"); + assertThat(localityStats.loadMetricStatsMap().get("cpu_utilization").totalMetricValue()) + .isWithin(TOLERANCE).of(0.5); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("mem_utilization"); + assertThat(localityStats.loadMetricStatsMap()).containsKey("named_metrics.named1"); + assertThat(localityStats.loadMetricStatsMap().get("named_metrics.named1").totalMetricValue()) + .isWithin(TOLERANCE).of(3.14159); + 
assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("named_metrics.named2"); + subchannel.shutdown(); + LoadStatsManager2.isEnabledOrcaLrsPropagation = originalVal; + } + + @Test + public void recordLoadStats_orcaLrsPropagationDisabled() { + boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; + LoadStatsManager2.isEnabledOrcaLrsPropagation = false; + BackendMetricPropagation backendMetricPropagation = BackendMetricPropagation.fromMetricSpecs( + Arrays.asList("application_utilization", "cpu_utilization", "named_metrics.named1")); + LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); + WeightedTargetConfig weightedTargetConfig = + buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), backendMetricPropagation); + EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); + deliverAddressesAndConfig(Collections.singletonList(endpoint), config); + FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); + Subchannel subchannel = leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getStatus().isOk()).isTrue(); + ClientStreamTracer streamTracer = result.getStreamTracerFactory().newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); + Metadata 
trailersWithOrcaLoadReport = new Metadata(); + trailersWithOrcaLoadReport.put(ORCA_ENDPOINT_LOAD_METRICS_KEY, + OrcaLoadReport.newBuilder() + .setApplicationUtilization(1.414) + .setCpuUtilization(0.5) + .setMemUtilization(0.034) + .putNamedMetrics("named1", 3.14159) + .putNamedMetrics("named2", -1.618).build()); + streamTracer.inboundTrailers(trailersWithOrcaLoadReport); + streamTracer.streamClosed(Status.OK); + ClusterStats clusterStats = + Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); + UpstreamLocalityStats localityStats = + Iterables.getOnlyElement(clusterStats.upstreamLocalityStatsList()); + + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("application_utilization"); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("cpu_utilization"); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("mem_utilization"); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("named_metrics.named1"); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("named_metrics.named2"); + assertThat(localityStats.loadMetricStatsMap().containsKey("named1")).isTrue(); + assertThat(localityStats.loadMetricStatsMap().containsKey("named2")).isTrue(); + subchannel.shutdown(); + LoadStatsManager2.isEnabledOrcaLrsPropagation = originalVal; + } + + // Verifies https://github.com/grpc/grpc-java/issues/11434. 
@Test public void pickFirstLoadReport_onUpdateAddress() { Locality locality1 = @@ -399,7 +530,7 @@ public void pickFirstLoadReport_onUpdateAddress() { null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(pickFirstProvider, pickFirstConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality1); EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr2", locality2); deliverAddressesAndConfig(Arrays.asList(endpoint1, endpoint2), config); @@ -407,6 +538,7 @@ public void pickFirstLoadReport_onUpdateAddress() { // Leaf balancer is created by Pick First. Get FakeSubchannel created to update attributes // A real subchannel would get these attributes from the connected address's EAG locality. FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); fakeSubchannel.setConnectedEagIndex(0); fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); assertThat(currentState).isEqualTo(ConnectivityState.READY); @@ -431,8 +563,17 @@ public void pickFirstLoadReport_onUpdateAddress() { fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); // Faksubchannel mimics update address and returns different locality - fakeSubchannel.setConnectedEagIndex(1); - fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + fakeSubchannel.updateState(ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription("Try second address instead"))); + fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + fakeSubchannel.setConnectedEagIndex(0); + 
fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + } else { + fakeSubchannel.setConnectedEagIndex(1); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + } result = currentPicker.pickSubchannel(pickSubchannelArgs); assertThat(result.getStatus().isOk()).isTrue(); ClientStreamTracer streamTracer2 = result.getStreamTracerFactory().newClientStreamTracer( @@ -479,7 +620,7 @@ public void dropRpcsWithRespectToLbConfigDropCategories() { null, Collections.singletonList(DropOverload.create("throttle", 500_000)), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); when(mockRandom.nextInt(anyInt())).thenReturn(499_999, 999_999, 1_000_000); @@ -490,6 +631,7 @@ public void dropRpcsWithRespectToLbConfigDropCategories() { .isEqualTo(endpoint.getAddresses()); leafBalancer.createSubChannel(); FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); fakeSubchannel.setConnectedEagIndex(0); fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); @@ -512,13 +654,13 @@ public void dropRpcsWithRespectToLbConfigDropCategories() { Collections.singletonList(DropOverload.create("lb", 1_000_000)), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.singletonList(endpoint)) .setAttributes( Attributes.newBuilder() - .set(InternalXdsAttributes.XDS_CLIENT_POOL, xdsClientPool) + 
.set(io.grpc.xds.XdsAttributes.XDS_CLIENT, xdsClient) .build()) .setLoadBalancingPolicyConfig(config) .build()); @@ -561,7 +703,7 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu maxConcurrentRequests, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); assertThat(downstreamBalancers).hasSize(1); // one leaf balancer @@ -571,6 +713,7 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu .isEqualTo(endpoint.getAddresses()); leafBalancer.createSubChannel(); FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); fakeSubchannel.setConnectedEagIndex(0); fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); assertThat(currentState).isEqualTo(ConnectivityState.READY); @@ -594,7 +737,7 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu assertThat(result.getStatus().isOk()).isFalse(); assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(result.getStatus().getDescription()) - .isEqualTo("Cluster max concurrent requests limit exceeded"); + .isEqualTo("Cluster max concurrent requests limit of 100 exceeded"); assertThat(clusterStats.totalDroppedRequests()).isEqualTo(1L); } else { assertThat(result.getStatus().isOk()).isTrue(); @@ -607,7 +750,7 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu maxConcurrentRequests, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - null, Collections.emptyMap()); + null, 
Collections.emptyMap(), null); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); result = currentPicker.pickSubchannel(pickSubchannelArgs); @@ -625,7 +768,7 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu assertThat(result.getStatus().isOk()).isFalse(); assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(result.getStatus().getDescription()) - .isEqualTo("Cluster max concurrent requests limit exceeded"); + .isEqualTo("Cluster max concurrent requests limit of 101 exceeded"); assertThat(clusterStats.totalDroppedRequests()).isEqualTo(1L); } else { assertThat(result.getStatus().isOk()).isTrue(); @@ -655,7 +798,7 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue( null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); assertThat(downstreamBalancers).hasSize(1); // one leaf balancer @@ -665,6 +808,7 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue( .isEqualTo(endpoint.getAddresses()); leafBalancer.createSubChannel(); FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); fakeSubchannel.setConnectedEagIndex(0); fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); assertThat(currentState).isEqualTo(ConnectivityState.READY); @@ -688,7 +832,7 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue( assertThat(result.getStatus().isOk()).isFalse(); assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(result.getStatus().getDescription()) - .isEqualTo("Cluster max concurrent requests 
limit exceeded"); + .isEqualTo("Cluster max concurrent requests limit of 1024 exceeded"); assertThat(clusterStats.totalDroppedRequests()).isEqualTo(1L); } else { assertThat(result.getStatus().isOk()).isTrue(); @@ -705,7 +849,7 @@ public void endpointAddressesAttachedWithClusterName() { null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); // One locality with two endpoints. EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality); EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr2", locality); @@ -721,22 +865,123 @@ public void endpointAddressesAttachedWithClusterName() { .build(); Subchannel subchannel = leafBalancer.helper.createSubchannel(args); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_CLUSTER_NAME)) + assertThat(eag.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_CLUSTER_NAME)) .isEqualTo(CLUSTER); } // An address update should also retain the cluster attribute. 
subchannel.updateAddresses(leafBalancer.addresses); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_CLUSTER_NAME)) + assertThat(eag.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_CLUSTER_NAME)) .isEqualTo(CLUSTER); } } + @Test + public void + endpointsWithAuthorityHostname_autoHostRewriteEnabled_pickResultHasAuthorityHostname() { + System.setProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", "true"); + try { + LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); + WeightedTargetConfig weightedTargetConfig = + buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); + EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality, + "authority-host-name"); + deliverAddressesAndConfig(Arrays.asList(endpoint1), config); + assertThat(downstreamBalancers).hasSize(1); // one leaf balancer + FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); + assertThat(leafBalancer.name).isEqualTo("round_robin"); + + // Simulates leaf load balancer creating subchannels. 
+ CreateSubchannelArgs args = + CreateSubchannelArgs.newBuilder() + .setAddresses(leafBalancer.addresses) + .build(); + Subchannel subchannel = leafBalancer.helper.createSubchannel(args); + subchannel.start(infoObject -> { + if (infoObject.getState() == ConnectivityState.READY) { + helper.updateBalancingState( + ConnectivityState.READY, + new FixedResultPicker(PickResult.withSubchannel(subchannel))); + } + }); + assertThat(subchannel.getAttributes().get(ATTR_SUBCHANNEL_ADDRESS_NAME)).isEqualTo( + "authority-host-name"); + for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { + assertThat(eag.getAttributes().get(XdsInternalAttributes.ATTR_ADDRESS_NAME)) + .isEqualTo("authority-host-name"); + } + + leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + assertThat(currentState).isEqualTo(ConnectivityState.READY); + PickDetailsConsumer detailsConsumer = mock(PickDetailsConsumer.class); + pickSubchannelArgs = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), + CallOptions.DEFAULT.withOption(AUTO_HOST_REWRITE_KEY, true), detailsConsumer); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getAuthorityOverride()).isEqualTo("authority-host-name"); + } finally { + System.clearProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE"); + } + } + + @Test + public void + endpointWithAuthorityHostname_autoHostRewriteNotEnabled_pickResultNoAuthorityHostname() { + LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); + WeightedTargetConfig weightedTargetConfig = + buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); + EquivalentAddressGroup endpoint1 = 
makeAddress("endpoint-addr1", locality, + "authority-host-name"); + deliverAddressesAndConfig(Arrays.asList(endpoint1), config); + assertThat(downstreamBalancers).hasSize(1); // one leaf balancer + FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); + assertThat(leafBalancer.name).isEqualTo("round_robin"); + + // Simulates leaf load balancer creating subchannels. + CreateSubchannelArgs args = + CreateSubchannelArgs.newBuilder() + .setAddresses(leafBalancer.addresses) + .build(); + Subchannel subchannel = leafBalancer.helper.createSubchannel(args); + subchannel.start(infoObject -> { + if (infoObject.getState() == ConnectivityState.READY) { + helper.updateBalancingState( + ConnectivityState.READY, + new FixedResultPicker(PickResult.withSubchannel(subchannel))); + } + }); + // Sub Channel wrapper args won't have the address name although addresses will. + assertThat(subchannel.getAttributes().get(ATTR_SUBCHANNEL_ADDRESS_NAME)).isNull(); + for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { + assertThat(eag.getAttributes().get(XdsInternalAttributes.ATTR_ADDRESS_NAME)) + .isEqualTo("authority-host-name"); + } + + leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + assertThat(currentState).isEqualTo(ConnectivityState.READY); + PickDetailsConsumer detailsConsumer = mock(PickDetailsConsumer.class); + pickSubchannelArgs = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, detailsConsumer); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getAuthorityOverride()).isNull(); + } + @Test public void endpointAddressesAttachedWithTlsConfig_securityEnabledByDefault() { UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe", true); LoadBalancerProvider 
weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); WeightedTargetConfig weightedTargetConfig = buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); @@ -744,7 +989,7 @@ public void endpointAddressesAttachedWithTlsConfig_securityEnabledByDefault() { null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - upstreamTlsContext, Collections.emptyMap()); + upstreamTlsContext, Collections.emptyMap(), null); // One locality with two endpoints. EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality); EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr2", locality); @@ -760,7 +1005,7 @@ public void endpointAddressesAttachedWithTlsConfig_securityEnabledByDefault() { Subchannel subchannel = leafBalancer.helper.createSubchannel(args); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { SslContextProviderSupplier supplier = - eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); + eag.getAttributes().get(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); assertThat(supplier.getTlsContext()).isEqualTo(upstreamTlsContext); } @@ -769,36 +1014,37 @@ public void endpointAddressesAttachedWithTlsConfig_securityEnabledByDefault() { null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - null, Collections.emptyMap()); + null, Collections.emptyMap(), null); deliverAddressesAndConfig(Arrays.asList(endpoint1, endpoint2), config); assertThat(Iterables.getOnlyElement(downstreamBalancers)).isSameInstanceAs(leafBalancer); subchannel = leafBalancer.helper.createSubchannel(args); // creates new connections for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER)) + assertThat( + 
eag.getAttributes().get(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER)) .isNull(); } // Config with a new UpstreamTlsContext. - upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe1", true); + upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe1", true); config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( weightedTargetProvider, weightedTargetConfig), - upstreamTlsContext, Collections.emptyMap()); + upstreamTlsContext, Collections.emptyMap(), null); deliverAddressesAndConfig(Arrays.asList(endpoint1, endpoint2), config); assertThat(Iterables.getOnlyElement(downstreamBalancers)).isSameInstanceAs(leafBalancer); subchannel = leafBalancer.helper.createSubchannel(args); // creates new connections for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { SslContextProviderSupplier supplier = - eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); + eag.getAttributes().get(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); assertThat(supplier.isShutdown()).isFalse(); assertThat(supplier.getTlsContext()).isEqualTo(upstreamTlsContext); } loadBalancer.shutdown(); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { SslContextProviderSupplier supplier = - eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); + eag.getAttributes().get(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); assertThat(supplier.isShutdown()).isTrue(); } loadBalancer = null; @@ -811,8 +1057,8 @@ private void deliverAddressesAndConfig(List addresses, .setAddresses(addresses) .setAttributes( Attributes.newBuilder() - .set(InternalXdsAttributes.XDS_CLIENT_POOL, xdsClientPool) - .set(InternalXdsAttributes.CALL_COUNTER_PROVIDER, callCounterProvider) + 
.set(io.grpc.xds.XdsAttributes.XDS_CLIENT, xdsClient) + .set(io.grpc.xds.XdsAttributes.CALL_COUNTER_PROVIDER, callCounterProvider) .build()) .setLoadBalancingPolicyConfig(config) .build()); @@ -833,6 +1079,11 @@ private WeightedTargetConfig buildWeightedTargetConfig(Map lo * Create a locality-labeled address. */ private static EquivalentAddressGroup makeAddress(final String name, Locality locality) { + return makeAddress(name, locality, null); + } + + private static EquivalentAddressGroup makeAddress(final String name, Locality locality, + String authorityHostname) { class FakeSocketAddress extends SocketAddress { private final String name; @@ -863,12 +1114,15 @@ public String toString() { } } + Attributes.Builder attributes = Attributes.newBuilder() + .set(io.grpc.xds.XdsAttributes.ATTR_LOCALITY, locality) + // Unique but arbitrary string + .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, locality.toString()); + if (authorityHostname != null) { + attributes.set(XdsInternalAttributes.ATTR_ADDRESS_NAME, authorityHostname); + } EquivalentAddressGroup eag = new EquivalentAddressGroup(new FakeSocketAddress(name), - Attributes.newBuilder() - .set(InternalXdsAttributes.ATTR_LOCALITY, locality) - // Unique but arbitrary string - .set(InternalXdsAttributes.ATTR_LOCALITY_NAME, locality.toString()) - .build()); + attributes.build()); return AddressFilter.setPathFilter(eag, Collections.singletonList(locality.toString())); } @@ -933,6 +1187,20 @@ public void shutdown() { downstreamBalancers.remove(this); } + void deliverSubchannelState(final Subchannel subchannel, ConnectivityState state) { + deliverSubchannelState(PickResult.withSubchannel(subchannel), state); + } + + void deliverSubchannelState(final PickResult result, ConnectivityState state) { + SubchannelPicker picker = new SubchannelPicker() { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + return result; + } + }; + helper.updateBalancingState(state, picker); + } + Subchannel createSubChannel() 
{ Subchannel subchannel = helper.createSubchannel( CreateSubchannelArgs.newBuilder().setAddresses(addresses).build()); @@ -943,6 +1211,7 @@ Subchannel createSubChannel() { new FixedResultPicker(PickResult.withSubchannel(subchannel))); } }); + subchannel.requestConnection(); return subchannel; } } @@ -989,6 +1258,8 @@ private static final class FakeSubchannel extends Subchannel { private final Attributes attrs; private SubchannelStateListener listener; private Attributes connectedAttributes; + private ConnectivityStateInfo state = ConnectivityStateInfo.forNonError(ConnectivityState.IDLE); + private boolean connectionRequested; private FakeSubchannel(List eags, Attributes attrs) { this.eags = eags; @@ -1006,6 +1277,9 @@ public void shutdown() { @Override public void requestConnection() { + if (state.getState() == ConnectivityState.IDLE) { + this.connectionRequested = true; + } } @Override @@ -1028,6 +1302,26 @@ public Attributes getConnectedAddressAttributes() { } public void updateState(ConnectivityStateInfo newState) { + switch (newState.getState()) { + case IDLE: + assertThat(state.getState()).isEqualTo(ConnectivityState.READY); + break; + case CONNECTING: + assertThat(state.getState()) + .isIn(Arrays.asList(ConnectivityState.IDLE, ConnectivityState.TRANSIENT_FAILURE)); + if (state.getState() == ConnectivityState.IDLE) { + assertWithMessage("Connection requested").that(this.connectionRequested).isTrue(); + this.connectionRequested = false; + } + break; + case READY: + case TRANSIENT_FAILURE: + assertThat(state.getState()).isEqualTo(ConnectivityState.CONNECTING); + break; + default: + break; + } + this.state = newState; listener.onSubchannelState(newState); } @@ -1037,6 +1331,7 @@ public void setConnectedEagIndex(int eagIndex) { } private final class FakeXdsClient extends XdsClient { + @Override public ClusterDropStats addClusterDropStats( ServerInfo lrsServerInfo, String clusterName, @Nullable String edsServiceName) { @@ -1046,8 +1341,9 @@ public ClusterDropStats 
addClusterDropStats( @Override public ClusterLocalityStats addClusterLocalityStats( ServerInfo lrsServerInfo, String clusterName, @Nullable String edsServiceName, - Locality locality) { - return loadStatsManager.getClusterLocalityStats(clusterName, edsServiceName, locality); + Locality locality, BackendMetricPropagation backendMetricPropagation) { + return loadStatsManager.getClusterLocalityStats( + clusterName, edsServiceName, locality, backendMetricPropagation); } @Override diff --git a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java index aa0e205dd8f..8856efd685f 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java @@ -118,7 +118,7 @@ public void tearDown() { } @Test - public void handleResolvedAddressesUpdatesChannelPicker() { + public void acceptResolvedAddressesUpdatesChannelPicker() { deliverResolvedAddresses(ImmutableMap.of("childA", "policy_a", "childB", "policy_b")); verify(helper, atLeastOnce()).updateBalancingState( diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerProviderTest.java deleted file mode 100644 index a201ecfaa4b..00000000000 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerProviderTest.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.truth.Truth.assertThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import io.grpc.ChannelLogger; -import io.grpc.LoadBalancer; -import io.grpc.LoadBalancer.Helper; -import io.grpc.LoadBalancerProvider; -import io.grpc.LoadBalancerRegistry; -import io.grpc.NameResolver; -import io.grpc.NameResolver.ServiceConfigParser; -import io.grpc.NameResolverRegistry; -import io.grpc.SynchronizationContext; -import io.grpc.internal.FakeClock; -import io.grpc.internal.GrpcUtil; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for {@link ClusterResolverLoadBalancerProvider}. */ -@RunWith(JUnit4.class) -public class ClusterResolverLoadBalancerProviderTest { - - @Test - public void provided() { - LoadBalancerProvider provider = - LoadBalancerRegistry.getDefaultRegistry().getProvider( - XdsLbPolicies.CLUSTER_RESOLVER_POLICY_NAME); - assertThat(provider).isInstanceOf(ClusterResolverLoadBalancerProvider.class); - } - - @Test - public void providesLoadBalancer() { - Helper helper = mock(Helper.class); - - SynchronizationContext syncContext = new SynchronizationContext( - new Thread.UncaughtExceptionHandler() { - @Override - public void uncaughtException(Thread t, Throwable e) { - throw new AssertionError(e); - } - }); - FakeClock fakeClock = new FakeClock(); - NameResolverRegistry nsRegistry = new NameResolverRegistry(); - NameResolver.Args args = NameResolver.Args.newBuilder() - .setDefaultPort(8080) - .setProxyDetector(GrpcUtil.NOOP_PROXY_DETECTOR) - .setSynchronizationContext(syncContext) - .setServiceConfigParser(mock(ServiceConfigParser.class)) - .setChannelLogger(mock(ChannelLogger.class)) - .build(); - when(helper.getNameResolverRegistry()).thenReturn(nsRegistry); - 
when(helper.getNameResolverArgs()).thenReturn(args); - when(helper.getSynchronizationContext()).thenReturn(syncContext); - when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService()); - when(helper.getAuthority()).thenReturn("api.google.com"); - LoadBalancerProvider provider = new ClusterResolverLoadBalancerProvider(); - LoadBalancer loadBalancer = provider.newLoadBalancer(helper); - assertThat(loadBalancer).isInstanceOf(ClusterResolverLoadBalancer.class); - } -} diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index d3c6bc8f9a0..c6e5db08526 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -21,22 +21,45 @@ import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME; import static io.grpc.xds.XdsLbPolicies.WRR_LOCALITY_POLICY_NAME; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; import static java.util.stream.Collectors.toList; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.protobuf.Any; +import com.google.protobuf.Duration; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import 
io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.OutlierDetection; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; +import io.envoyproxy.envoy.config.core.v3.ConfigSource; +import io.envoyproxy.envoy.config.core.v3.HealthStatus; +import io.envoyproxy.envoy.config.core.v3.Locality; +import io.envoyproxy.envoy.config.core.v3.Metadata; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.core.v3.TransportSocket; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; +import io.envoyproxy.envoy.extensions.transport_sockets.http_11_proxy.v3.Http11ProxyUpstreamTransport; import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; -import io.grpc.InsecureChannelCredentials; +import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; @@ -51,49 +74,36 @@ import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; import io.grpc.Status; -import io.grpc.Status.Code; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; -import io.grpc.internal.BackoffPolicy; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.FakeClock; -import io.grpc.internal.FakeClock.ScheduledTask; import io.grpc.internal.GrpcUtil; -import io.grpc.internal.ObjectPool; -import io.grpc.util.GracefulSwitchLoadBalancer; +import io.grpc.testing.GrpcCleanupRule; import io.grpc.util.GracefulSwitchLoadBalancerAccessor; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; import 
io.grpc.util.OutlierDetectionLoadBalancerProvider; +import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig.DiscoveryMechanism; import io.grpc.xds.Endpoints.DropOverload; -import io.grpc.xds.Endpoints.LbEndpoint; -import io.grpc.xds.Endpoints.LocalityLbEndpoints; -import io.grpc.xds.EnvoyServerProtoData.FailurePercentageEjection; -import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; -import io.grpc.xds.EnvoyServerProtoData.SuccessRateEjection; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import io.grpc.xds.WrrLocalityLoadBalancer.WrrLocalityConfig; -import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.client.BackendMetricPropagation; import io.grpc.xds.client.Bootstrapper.ServerInfo; -import io.grpc.xds.client.Locality; +import io.grpc.xds.client.LoadStatsManager2; import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsResourceType; -import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; -import java.net.SocketAddress; +import io.grpc.xds.internal.XdsInternalAttributes; +import java.net.InetSocketAddress; import java.net.URI; -import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; +import java.util.Iterator; import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; import org.junit.After; @@ -104,9 +114,7 
@@ import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.Captor; -import org.mockito.InOrder; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -114,39 +122,43 @@ @RunWith(JUnit4.class) public class ClusterResolverLoadBalancerTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + @Rule + public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); - private static final String AUTHORITY = "api.google.com"; - private static final String CLUSTER1 = "cluster-foo.googleapis.com"; - private static final String CLUSTER2 = "cluster-bar.googleapis.com"; - private static final String CLUSTER_DNS = "cluster-dns.googleapis.com"; - private static final String EDS_SERVICE_NAME1 = "backend-service-foo.googleapis.com"; - private static final String EDS_SERVICE_NAME2 = "backend-service-bar.googleapis.com"; + private static final String SERVER_NAME = "example.com"; + private static final String CLUSTER = "cluster-foo.googleapis.com"; + private static final String EDS_SERVICE_NAME = "backend-service-foo.googleapis.com"; private static final String DNS_HOST_NAME = "dns-service.googleapis.com"; - private static final ServerInfo LRS_SERVER_INFO = - ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); - private final Locality locality1 = - Locality.create("test-region-1", "test-zone-1", "test-subzone-1"); - private final Locality locality2 = - Locality.create("test-region-2", "test-zone-2", "test-subzone-2"); - private final Locality locality3 = - Locality.create("test-region-3", "test-zone-3", "test-subzone-3"); - private final UpstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); - private final OutlierDetection outlierDetection = OutlierDetection.create( - 100L, 100L, 100L, 100, SuccessRateEjection.create(100, 100, 100, 100), - 
FailurePercentageEjection.create(100, 100, 100, 100)); - private final DiscoveryMechanism edsDiscoveryMechanism1 = - DiscoveryMechanism.forEds(CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, tlsContext, - Collections.emptyMap(), null); - private final DiscoveryMechanism edsDiscoveryMechanism2 = - DiscoveryMechanism.forEds(CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_INFO, 200L, tlsContext, - Collections.emptyMap(), null); - private final DiscoveryMechanism edsDiscoveryMechanismWithOutlierDetection = - DiscoveryMechanism.forEds(CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, tlsContext, - Collections.emptyMap(), outlierDetection); - private final DiscoveryMechanism logicalDnsDiscoveryMechanism = - DiscoveryMechanism.forLogicalDns(CLUSTER_DNS, DNS_HOST_NAME, LRS_SERVER_INFO, 300L, null, - Collections.emptyMap()); + private static final Cluster EDS_CLUSTER = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .build(); + private static final Cluster LOGICAL_DNS_CLUSTER = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.LOGICAL_DNS) + .setLoadAssignment(ClusterLoadAssignment.newBuilder() + .addEndpoints(LocalityLbEndpoints.newBuilder() + .addLbEndpoints(newSocketLbEndpoint(DNS_HOST_NAME, 9000)))) + .build(); + private static final Locality LOCALITY1 = Locality.newBuilder() + .setRegion("test-region-1") + .setZone("test-zone-1") + .setSubZone("test-subzone-1") + .build(); + private static final Locality LOCALITY2 = Locality.newBuilder() + .setRegion("test-region-2") + .setZone("test-zone-2") + .setSubZone("test-subzone-2") + .build(); + private static final Locality LOCALITY3 = Locality.newBuilder() + .setRegion("test-region-3") + .setZone("test-zone-3") + .setSubZone("test-subzone-3") + .build(); private final 
SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @@ -158,49 +170,31 @@ public void uncaughtException(Thread t, Throwable e) { private final FakeClock fakeClock = new FakeClock(); private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); private final NameResolverRegistry nsRegistry = new NameResolverRegistry(); - private final Object roundRobin = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( - new FakeLoadBalancerProvider("wrr_locality_experimental"), new WrrLocalityConfig( - GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( - new FakeLoadBalancerProvider("round_robin"), null))); - private final Object ringHash = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( - new FakeLoadBalancerProvider("ring_hash_experimental"), new RingHashConfig(10L, 100L)); - private final Object leastRequest = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( - new FakeLoadBalancerProvider("wrr_locality_experimental"), new WrrLocalityConfig( - GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( - new FakeLoadBalancerProvider("least_request_experimental"), - new LeastRequestConfig(3)))); private final List childBalancers = new ArrayList<>(); private final List resolvers = new ArrayList<>(); - private final FakeXdsClient xdsClient = new FakeXdsClient(); - private final ObjectPool xdsClientPool = new ObjectPool() { - @Override - public XdsClient getObject() { - xdsClientRefs++; - return xdsClient; - } - - @Override - public XdsClient returnObject(Object object) { - xdsClientRefs--; - return null; - } - }; - + private final XdsTestControlPlaneService controlPlaneService = new XdsTestControlPlaneService(); + private final XdsClient xdsClient = XdsTestUtils.createXdsClient( + Arrays.asList("control-plane.example.com"), + serverInfo -> new GrpcXdsTransportFactory.GrpcXdsTransport( + InProcessChannelBuilder + .forName(serverInfo.target()) + .directExecutor() + .build()), 
+ fakeClock); + + + private XdsDependencyManager xdsDepManager; @Mock private Helper helper; - @Mock - private BackoffPolicy.Provider backoffPolicyProvider; - @Mock - private BackoffPolicy backoffPolicy1; - @Mock - private BackoffPolicy backoffPolicy2; @Captor private ArgumentCaptor pickerCaptor; - private int xdsClientRefs; - private ClusterResolverLoadBalancer loadBalancer; + private CdsLoadBalancer2 loadBalancer; + private boolean originalIsEnabledXdsHttpConnect; @Before - public void setUp() throws URISyntaxException { + public void setUp() throws Exception { + lbRegistry.register(new RingHashLoadBalancerProvider()); + lbRegistry.register(new WrrLocalityLoadBalancerProvider()); lbRegistry.register(new FakeLoadBalancerProvider(PRIORITY_POLICY_NAME)); lbRegistry.register(new FakeLoadBalancerProvider(CLUSTER_IMPL_POLICY_NAME)); lbRegistry.register(new FakeLoadBalancerProvider(WEIGHTED_TARGET_POLICY_NAME)); @@ -213,76 +207,129 @@ public void setUp() throws URISyntaxException { .setSynchronizationContext(syncContext) .setServiceConfigParser(mock(ServiceConfigParser.class)) .setChannelLogger(mock(ChannelLogger.class)) + .setScheduledExecutorService(fakeClock.getScheduledExecutorService()) + .setNameResolverRegistry(nsRegistry) .build(); + + xdsDepManager = new XdsDependencyManager( + xdsClient, + syncContext, + SERVER_NAME, + SERVER_NAME, + args); + + cleanupRule.register(InProcessServerBuilder + .forName("control-plane.example.com") + .addService(controlPlaneService) + .directExecutor() + .build() + .start()); + + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, ImmutableMap.of( + SERVER_NAME, ControlPlaneRule.buildClientListener(SERVER_NAME, "my-route"))); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_RDS, ImmutableMap.of( + "my-route", XdsTestUtils.buildRouteConfiguration(SERVER_NAME, "my-route", CLUSTER))); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, EDS_CLUSTER)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, 
ImmutableMap.of( + EDS_SERVICE_NAME, ControlPlaneRule.buildClusterLoadAssignment( + "127.0.0.1", "", 8080, EDS_SERVICE_NAME))); + nsRegistry.register(new FakeNameResolverProvider()); - when(helper.getNameResolverRegistry()).thenReturn(nsRegistry); - when(helper.getNameResolverArgs()).thenReturn(args); - when(helper.getSynchronizationContext()).thenReturn(syncContext); - when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService()); - when(helper.getAuthority()).thenReturn(AUTHORITY); - when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2); - when(backoffPolicy1.nextBackoffNanos()) - .thenReturn(TimeUnit.SECONDS.toNanos(1L), TimeUnit.SECONDS.toNanos(10L)); - when(backoffPolicy2.nextBackoffNanos()) - .thenReturn(TimeUnit.SECONDS.toNanos(5L), TimeUnit.SECONDS.toNanos(50L)); - loadBalancer = new ClusterResolverLoadBalancer(helper, lbRegistry, backoffPolicyProvider); + when(helper.getAuthority()).thenReturn("api.google.com"); + doAnswer((inv) -> { + xdsDepManager.requestReresolution(); + return null; + }).when(helper).refreshNameResolution(); + loadBalancer = new CdsLoadBalancer2(helper, lbRegistry); + + originalIsEnabledXdsHttpConnect = XdsClusterResource.isEnabledXdsHttpConnect; } @After - public void tearDown() { + public void tearDown() throws Exception { + XdsClusterResource.isEnabledXdsHttpConnect = originalIsEnabledXdsHttpConnect; loadBalancer.shutdown(); + if (xdsDepManager != null) { + xdsDepManager.shutdown(); + } + assertThat(xdsClient.getSubscribedResourcesMetadataSnapshot().get()).isEmpty(); + xdsClient.shutdown(); + assertThat(childBalancers).isEmpty(); assertThat(resolvers).isEmpty(); - assertThat(xdsClient.watchers).isEmpty(); - assertThat(xdsClientRefs).isEqualTo(0); assertThat(fakeClock.getPendingTasks()).isEmpty(); } @Test - public void edsClustersWithRingHashEndpointLbPolicy() { - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(edsDiscoveryMechanism1), 
ringHash); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(childBalancers).isEmpty(); + public void edsClustersWithRingHashEndpointLbPolicy_oppositePickFirstWeightedShuffling() + throws Exception { + boolean original = CdsLoadBalancer2.pickFirstWeightedShuffling; + CdsLoadBalancer2.pickFirstWeightedShuffling = !CdsLoadBalancer2.pickFirstWeightedShuffling; + try { + edsClustersWithRingHashEndpointLbPolicy(); + } finally { + CdsLoadBalancer2.pickFirstWeightedShuffling = original; + } + } + + @Test + public void edsClustersWithRingHashEndpointLbPolicy() throws Exception { + boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; + LoadStatsManager2.isEnabledOrcaLrsPropagation = true; + List metricSpecs = Arrays.asList("cpu_utilization"); + BackendMetricPropagation backendMetricPropagation = + BackendMetricPropagation.fromMetricSpecs(metricSpecs); + Cluster cluster = EDS_CLUSTER.toBuilder() + .setLbPolicy(Cluster.LbPolicy.RING_HASH) + .setRingHashLbConfig(Cluster.RingHashLbConfig.newBuilder() + .setMinimumRingSize(UInt64Value.of(10)) + .setMaximumRingSize(UInt64Value.of(100)) + .build()) + .addAllLrsReportEndpointMetrics(metricSpecs) + .build(); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(10)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8080)) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.2", 8080))) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(50)) + .setLocality(LOCALITY2) + .addLbEndpoints(newSocketLbEndpoint("127.0.1.1", 8080) + .setLoadBalancingWeight(UInt32Value.of(60)))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, cluster)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( 
+ EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); - // One priority with two localities of different weights. - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - EquivalentAddressGroup endpoint3 = makeAddress("endpoint-addr-3"); - LocalityLbEndpoints localityLbEndpoints1 = - LocalityLbEndpoints.create( - Arrays.asList( - LbEndpoint.create(endpoint1, 0 /* loadBalancingWeight */, true), - LbEndpoint.create(endpoint2, 0 /* loadBalancingWeight */, true)), - 10 /* localityWeight */, 1 /* priority */); - LocalityLbEndpoints localityLbEndpoints2 = - LocalityLbEndpoints.create( - Collections.singletonList( - LbEndpoint.create(endpoint3, 60 /* loadBalancingWeight */, true)), - 50 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, - ImmutableMap.of(locality1, localityLbEndpoints1, locality2, localityLbEndpoints2)); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.addresses).hasSize(3); EquivalentAddressGroup addr1 = childBalancer.addresses.get(0); EquivalentAddressGroup addr2 = childBalancer.addresses.get(1); EquivalentAddressGroup addr3 = childBalancer.addresses.get(2); - // Endpoints in locality1 have no endpoint-level weight specified, so all endpoints within - // locality1 are equally weighted. 
- assertThat(addr1.getAddresses()).isEqualTo(endpoint1.getAddresses()); - assertThat(addr1.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10); - assertThat(addr2.getAddresses()).isEqualTo(endpoint2.getAddresses()); - assertThat(addr2.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10); - assertThat(addr3.getAddresses()).isEqualTo(endpoint3.getAddresses()); - assertThat(addr3.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(50 * 60); + // Endpoints in LOCALITY1 have no endpoint-level weight specified, so all endpoints within + // LOCALITY1 are equally weighted. + assertThat(addr1.getAddresses()) + .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.1", 8080))); + assertThat(addr1.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); + assertThat(addr2.getAddresses()) + .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.2", 8080))); + assertThat(addr2.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); + assertThat(addr3.getAddresses()) + .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.1.1", 8080))); + assertThat(addr3.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 
0x6AAAAAAA /* 5/6 */ : 50 * 60); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; - assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER1 + "[child1]"); + assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER + "[child1]"); PriorityChildConfig priorityChildConfig = Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); assertThat(priorityChildConfig.ignoreReresolution).isTrue(); @@ -291,8 +338,10 @@ public void edsClustersWithRingHashEndpointLbPolicy() { .isEqualTo(CLUSTER_IMPL_POLICY_NAME); ClusterImplConfig clusterImplConfig = (ClusterImplConfig) GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig.childConfig); - assertClusterImplConfig(clusterImplConfig, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, - tlsContext, Collections.emptyList(), "ring_hash_experimental"); + assertClusterImplConfig(clusterImplConfig, CLUSTER, EDS_SERVICE_NAME, null, null, + null, Collections.emptyList(), "ring_hash_experimental"); + assertThat(clusterImplConfig.backendMetricPropagation).isEqualTo(backendMetricPropagation); + LoadStatsManager2.isEnabledOrcaLrsPropagation = originalVal; RingHashConfig ringHashConfig = (RingHashConfig) GracefulSwitchLoadBalancerAccessor.getChildConfig(clusterImplConfig.childConfig); assertThat(ringHashConfig.minRingSize).isEqualTo(10L); @@ -301,39 +350,42 @@ public void edsClustersWithRingHashEndpointLbPolicy() { @Test public void edsClustersWithLeastRequestEndpointLbPolicy() { - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(edsDiscoveryMechanism1), leastRequest); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(childBalancers).isEmpty(); - + Cluster cluster = EDS_CLUSTER.toBuilder() + .setLbPolicy(Cluster.LbPolicy.LEAST_REQUEST) + .build(); // Simple case with one priority and one locality - 
EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Arrays.asList( - LbEndpoint.create(endpoint, 0 /* loadBalancingWeight */, true)), - 100 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, - ImmutableMap.of(locality1, localityLbEndpoints)); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8080))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, cluster)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.addresses).hasSize(1); EquivalentAddressGroup addr = childBalancer.addresses.get(0); - assertThat(addr.getAddresses()).isEqualTo(endpoint.getAddresses()); + assertThat(addr.getAddresses()) + .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.1", 8080))); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; - assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER1 + "[child1]"); + assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER + "[child1]"); PriorityChildConfig priorityChildConfig = Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider(priorityChildConfig.childConfig) - .getPolicyName()) + .getPolicyName()) 
.isEqualTo(CLUSTER_IMPL_POLICY_NAME); ClusterImplConfig clusterImplConfig = (ClusterImplConfig) GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig.childConfig); - assertClusterImplConfig(clusterImplConfig, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, - tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); + assertClusterImplConfig(clusterImplConfig, CLUSTER, EDS_SERVICE_NAME, null, null, + null, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); WrrLocalityConfig wrrLocalityConfig = (WrrLocalityConfig) GracefulSwitchLoadBalancerAccessor.getChildConfig(clusterImplConfig.childConfig); LoadBalancerProvider childProvider = @@ -342,154 +394,172 @@ public void edsClustersWithLeastRequestEndpointLbPolicy() { assertThat( childBalancer.addresses.get(0).getAttributes() - .get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)).isEqualTo(100); + .get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY_WEIGHT)).isEqualTo(100); } @Test - public void edsClustersWithOutlierDetection() { - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(edsDiscoveryMechanismWithOutlierDetection), leastRequest); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(childBalancers).isEmpty(); - + public void edsClustersEndpointHostname_addedToAddressAttribute() { // Simple case with one priority and one locality - EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Arrays.asList( - LbEndpoint.create(endpoint, 0 /* loadBalancingWeight */, true)), - 100 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, - ImmutableMap.of(locality1, localityLbEndpoints)); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + 
.setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setHostname("hostname1") + .setAddress(newAddress("127.0.0.1", 8000))))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.addresses).hasSize(1); - EquivalentAddressGroup addr = childBalancer.addresses.get(0); - assertThat(addr.getAddresses()).isEqualTo(endpoint.getAddresses()); - assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); - PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; - assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER1 + "[child1]"); - PriorityChildConfig priorityChildConfig = - Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); - - // The child config for priority should be outlier detection. - assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider(priorityChildConfig.childConfig) - .getPolicyName()) - .isEqualTo("outlier_detection_experimental"); - OutlierDetectionLoadBalancerConfig outlierDetectionConfig = (OutlierDetectionLoadBalancerConfig) - GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig.childConfig); - - // The outlier detection config should faithfully represent what came down from xDS. 
- assertThat(outlierDetectionConfig.intervalNanos).isEqualTo(outlierDetection.intervalNanos()); - assertThat(outlierDetectionConfig.baseEjectionTimeNanos).isEqualTo( - outlierDetection.baseEjectionTimeNanos()); - assertThat(outlierDetectionConfig.baseEjectionTimeNanos).isEqualTo( - outlierDetection.baseEjectionTimeNanos()); - assertThat(outlierDetectionConfig.maxEjectionTimeNanos).isEqualTo( - outlierDetection.maxEjectionTimeNanos()); - assertThat(outlierDetectionConfig.maxEjectionPercent).isEqualTo( - outlierDetection.maxEjectionPercent()); - - OutlierDetectionLoadBalancerConfig.SuccessRateEjection successRateEjection - = outlierDetectionConfig.successRateEjection; - assertThat(successRateEjection.stdevFactor).isEqualTo( - outlierDetection.successRateEjection().stdevFactor()); - assertThat(successRateEjection.enforcementPercentage).isEqualTo( - outlierDetection.successRateEjection().enforcementPercentage()); - assertThat(successRateEjection.minimumHosts).isEqualTo( - outlierDetection.successRateEjection().minimumHosts()); - assertThat(successRateEjection.requestVolume).isEqualTo( - outlierDetection.successRateEjection().requestVolume()); - - OutlierDetectionLoadBalancerConfig.FailurePercentageEjection failurePercentageEjection - = outlierDetectionConfig.failurePercentageEjection; - assertThat(failurePercentageEjection.threshold).isEqualTo( - outlierDetection.failurePercentageEjection().threshold()); - assertThat(failurePercentageEjection.enforcementPercentage).isEqualTo( - outlierDetection.failurePercentageEjection().enforcementPercentage()); - assertThat(failurePercentageEjection.minimumHosts).isEqualTo( - outlierDetection.failurePercentageEjection().minimumHosts()); - assertThat(failurePercentageEjection.requestVolume).isEqualTo( - outlierDetection.failurePercentageEjection().requestVolume()); - - // The wrapped configuration should not have been tampered with. 
- ClusterImplConfig clusterImplConfig = (ClusterImplConfig) - GracefulSwitchLoadBalancerAccessor.getChildConfig(outlierDetectionConfig.childConfig); - assertClusterImplConfig(clusterImplConfig, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, - tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); - WrrLocalityConfig wrrLocalityConfig = (WrrLocalityConfig) - GracefulSwitchLoadBalancerAccessor.getChildConfig(clusterImplConfig.childConfig); - LoadBalancerProvider childProvider = - GracefulSwitchLoadBalancerAccessor.getChildProvider(wrrLocalityConfig.childConfig); - assertThat(childProvider.getPolicyName()).isEqualTo("least_request_experimental"); assertThat( childBalancer.addresses.get(0).getAttributes() - .get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)).isEqualTo(100); + .get(XdsInternalAttributes.ATTR_ADDRESS_NAME)).isEqualTo("hostname1"); } + @Test + public void endpointAddressRewritten_whenProxyMetadataIsInEndpointMetadata() { + XdsClusterResource.isEnabledXdsHttpConnect = true; + Cluster cluster = EDS_CLUSTER.toBuilder() + .setTransportSocket(TransportSocket.newBuilder() + .setName( + "type.googleapis.com/" + Http11ProxyUpstreamTransport.getDescriptor().getFullName()) + .setTypedConfig(Any.pack(Http11ProxyUpstreamTransport.getDefaultInstance()))) + .build(); + // Proxy address in endpointMetadata, and no proxy in locality metadata + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8080) + .setMetadata(Metadata.newBuilder() + .putTypedFilterMetadata( + "envoy.http11_proxy_transport_socket.proxy_address", + Any.pack(newAddress("127.0.0.2", 8081).build())))) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.3", 8082))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, 
cluster)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); + FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); + + // Get the rewritten address + java.net.SocketAddress rewrittenAddress = + childBalancer.addresses.get(0).getAddresses().get(0); + assertThat(rewrittenAddress).isInstanceOf(HttpConnectProxiedSocketAddress.class); + HttpConnectProxiedSocketAddress proxiedSocket = + (HttpConnectProxiedSocketAddress) rewrittenAddress; + + // Assert that the target address is the original address + assertThat(proxiedSocket.getTargetAddress()).isEqualTo(newInetSocketAddress("127.0.0.1", 8080)); + + // Assert that the proxy address is correctly set + assertThat(proxiedSocket.getProxyAddress()).isEqualTo(newInetSocketAddress("127.0.0.2", 8081)); + + // Check the non-rewritten address + java.net.SocketAddress normalAddress = childBalancer.addresses.get(1).getAddresses().get(0); + assertThat(normalAddress).isEqualTo(newInetSocketAddress("127.0.0.3", 8082)); + } + + @Test + public void endpointAddressRewritten_whenProxyMetadataIsInLocalityMetadata() { + XdsClusterResource.isEnabledXdsHttpConnect = true; + Cluster cluster = EDS_CLUSTER.toBuilder() + .setTransportSocket(TransportSocket.newBuilder() + .setName( + "type.googleapis.com/" + Http11ProxyUpstreamTransport.getDescriptor().getFullName()) + .setTypedConfig(Any.pack(Http11ProxyUpstreamTransport.getDefaultInstance()))) + .build(); + // No proxy address in endpointMetadata, and proxy in locality metadata + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + 
.addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8080)) + .setMetadata(Metadata.newBuilder() + .putTypedFilterMetadata( + "envoy.http11_proxy_transport_socket.proxy_address", + Any.pack(newAddress("127.0.0.2", 8081).build())))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, cluster)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); + FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); + + // Get the rewritten address + java.net.SocketAddress rewrittenAddress = childBalancer.addresses.get(0).getAddresses().get(0); + + // Assert that the address was rewritten + assertThat(rewrittenAddress).isInstanceOf(HttpConnectProxiedSocketAddress.class); + HttpConnectProxiedSocketAddress proxiedSocket = + (HttpConnectProxiedSocketAddress) rewrittenAddress; + + // Assert that the target address is the original address + assertThat(proxiedSocket.getTargetAddress()).isEqualTo(newInetSocketAddress("127.0.0.1", 8080)); + + // Assert that the proxy address is correctly set from locality metadata + assertThat(proxiedSocket.getProxyAddress()).isEqualTo(newInetSocketAddress("127.0.0.2", 8081)); + } @Test public void onlyEdsClusters_receivedEndpoints() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, edsDiscoveryMechanism2), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1, EDS_SERVICE_NAME2); - assertThat(childBalancers).isEmpty(); - // CLUSTER1 has priority 1 (priority3), which has locality 2, which has endpoint3. 
- // CLUSTER2 has priority 1 (priority1) and 2 (priority2); priority1 has locality1, - // which has endpoint1 and endpoint2; priority2 has locality3, which has endpoint4. - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - EquivalentAddressGroup endpoint3 = makeAddress("endpoint-addr-3"); - EquivalentAddressGroup endpoint4 = makeAddress("endpoint-addr-4"); - LocalityLbEndpoints localityLbEndpoints1 = - LocalityLbEndpoints.create( - Arrays.asList( - LbEndpoint.create(endpoint1, 100, true), - LbEndpoint.create(endpoint2, 100, true)), - 70 /* localityWeight */, 1 /* priority */); - LocalityLbEndpoints localityLbEndpoints2 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint3, 100, true)), - 10 /* localityWeight */, 1 /* priority */); - LocalityLbEndpoints localityLbEndpoints3 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint4, 100, true)), - 20 /* localityWeight */, 2 /* priority */); - String priority1 = CLUSTER2 + "[child1]"; - String priority2 = CLUSTER2 + "[child2]"; - String priority3 = CLUSTER1 + "[child1]"; - - // CLUSTER2: locality1 with priority 1 and locality3 with priority 2. - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME2, - ImmutableMap.of(locality1, localityLbEndpoints1, locality3, localityLbEndpoints3)); - assertThat(childBalancers).isEmpty(); // not created until all clusters resolved - - // CLUSTER1: locality2 with priority 1. - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality2, localityLbEndpoints2)); - - // Endpoints of all clusters have been resolved. 
+ // Has two localities with different priorities + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(70)) + .setPriority(0) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8080)) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.2", 8080))) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(30)) + .setPriority(1) + .setLocality(LOCALITY2) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.3", 8080))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + + String priority1 = CLUSTER + "[child1]"; + String priority2 = CLUSTER + "[child2]"; + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; assertThat(priorityLbConfig.priorities) - .containsExactly(priority3, priority1, priority2).inOrder(); + .containsExactly(priority1, priority2).inOrder(); PriorityChildConfig priorityChildConfig1 = priorityLbConfig.childConfigs.get(priority1); assertThat(priorityChildConfig1.ignoreReresolution).isTrue(); assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider(priorityChildConfig1.childConfig) - .getPolicyName()) + .getPolicyName()) .isEqualTo(CLUSTER_IMPL_POLICY_NAME); ClusterImplConfig clusterImplConfig1 = (ClusterImplConfig) GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig1.childConfig); - assertClusterImplConfig(clusterImplConfig1, CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_INFO, 200L, - tlsContext, Collections.emptyList(), 
WRR_LOCALITY_POLICY_NAME); + assertClusterImplConfig(clusterImplConfig1, CLUSTER, EDS_SERVICE_NAME, null, null, + null, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); WrrLocalityConfig wrrLocalityConfig1 = (WrrLocalityConfig) GracefulSwitchLoadBalancerAccessor.getChildConfig(clusterImplConfig1.childConfig); LoadBalancerProvider childProvider1 = @@ -499,63 +569,60 @@ public void onlyEdsClusters_receivedEndpoints() { PriorityChildConfig priorityChildConfig2 = priorityLbConfig.childConfigs.get(priority2); assertThat(priorityChildConfig2.ignoreReresolution).isTrue(); assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider(priorityChildConfig2.childConfig) - .getPolicyName()) + .getPolicyName()) .isEqualTo(CLUSTER_IMPL_POLICY_NAME); ClusterImplConfig clusterImplConfig2 = (ClusterImplConfig) GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig2.childConfig); - assertClusterImplConfig(clusterImplConfig2, CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_INFO, 200L, - tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); + assertClusterImplConfig(clusterImplConfig2, CLUSTER, EDS_SERVICE_NAME, null, null, + null, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); WrrLocalityConfig wrrLocalityConfig2 = (WrrLocalityConfig) GracefulSwitchLoadBalancerAccessor.getChildConfig(clusterImplConfig1.childConfig); LoadBalancerProvider childProvider2 = GracefulSwitchLoadBalancerAccessor.getChildProvider(wrrLocalityConfig2.childConfig); assertThat(childProvider2.getPolicyName()).isEqualTo("round_robin"); - PriorityChildConfig priorityChildConfig3 = priorityLbConfig.childConfigs.get(priority3); - assertThat(priorityChildConfig3.ignoreReresolution).isTrue(); - assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider(priorityChildConfig3.childConfig) - .getPolicyName()) - .isEqualTo(CLUSTER_IMPL_POLICY_NAME); - ClusterImplConfig clusterImplConfig3 = (ClusterImplConfig) - GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig3.childConfig); 
- assertClusterImplConfig(clusterImplConfig3, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, - tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); WrrLocalityConfig wrrLocalityConfig3 = (WrrLocalityConfig) GracefulSwitchLoadBalancerAccessor.getChildConfig(clusterImplConfig1.childConfig); LoadBalancerProvider childProvider3 = GracefulSwitchLoadBalancerAccessor.getChildProvider(wrrLocalityConfig3.childConfig); assertThat(childProvider3.getPolicyName()).isEqualTo("round_robin"); + io.grpc.xds.client.Locality locality1 = io.grpc.xds.client.Locality.create( + LOCALITY1.getRegion(), LOCALITY1.getZone(), LOCALITY1.getSubZone()); + io.grpc.xds.client.Locality locality2 = io.grpc.xds.client.Locality.create( + LOCALITY2.getRegion(), LOCALITY2.getZone(), LOCALITY2.getSubZone()); for (EquivalentAddressGroup eag : childBalancer.addresses) { - if (eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY) == locality1) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)) + io.grpc.xds.client.Locality locality = + eag.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY); + if (locality.equals(locality1)) { + assertThat(eag.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY_WEIGHT)) .isEqualTo(70); - } - if (eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY) == locality2) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)) - .isEqualTo(10); - } - if (eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY) == locality3) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)) - .isEqualTo(20); + } else if (locality.equals(locality2)) { + assertThat(eag.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY_WEIGHT)) + .isEqualTo(30); + } else { + throw new AssertionError("Unexpected locality region: " + locality.region()); } } } @SuppressWarnings("unchecked") - private void verifyEdsPriorityNames(List want, - Map... 
updates) { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism2), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME2); - assertThat(childBalancers).isEmpty(); - - for (Map update: updates) { - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME2, - update); + private void verifyEdsPriorityNames(List want, List... updates) { + Iterator edsUpdates = Arrays.asList(updates).stream() + .map(update -> ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addAllEndpoints(update) + .build()) + .iterator(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, edsUpdates.next())); + startXdsDepManager(); + + while (edsUpdates.hasNext()) { + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, edsUpdates.next())); } + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); @@ -566,218 +633,273 @@ private void verifyEdsPriorityNames(List want, @Test @SuppressWarnings("unchecked") public void edsUpdatePriorityName_twoPriorities() { - verifyEdsPriorityNames(Arrays.asList(CLUSTER2 + "[child1]", CLUSTER2 + "[child2]"), - ImmutableMap.of(locality1, createEndpoints(1), - locality2, createEndpoints(2) - )); + verifyEdsPriorityNames(Arrays.asList(CLUSTER + "[child1]", CLUSTER + "[child2]"), + Arrays.asList(createEndpoints(LOCALITY1, 0), createEndpoints(LOCALITY2, 1))); } @Test @SuppressWarnings("unchecked") public void edsUpdatePriorityName_addOnePriority() { - verifyEdsPriorityNames(Arrays.asList(CLUSTER2 + "[child2]"), - ImmutableMap.of(locality1, createEndpoints(1)), - ImmutableMap.of(locality2, createEndpoints(1) - )); + 
verifyEdsPriorityNames(Arrays.asList(CLUSTER + "[child2]"), + Arrays.asList(createEndpoints(LOCALITY1, 0)), + Arrays.asList(createEndpoints(LOCALITY2, 0))); } @Test @SuppressWarnings("unchecked") public void edsUpdatePriorityName_swapTwoPriorities() { - verifyEdsPriorityNames(Arrays.asList(CLUSTER2 + "[child2]", CLUSTER2 + "[child1]", - CLUSTER2 + "[child3]"), - ImmutableMap.of(locality1, createEndpoints(1), - locality2, createEndpoints(2), - locality3, createEndpoints(3) - ), - ImmutableMap.of(locality1, createEndpoints(2), - locality2, createEndpoints(1), - locality3, createEndpoints(3)) - ); + verifyEdsPriorityNames(Arrays.asList(CLUSTER + "[child2]", CLUSTER + "[child1]", + CLUSTER + "[child3]"), + Arrays.asList( + createEndpoints(LOCALITY1, 0), + createEndpoints(LOCALITY2, 1), + createEndpoints(LOCALITY3, 2)), + Arrays.asList( + createEndpoints(LOCALITY1, 1), + createEndpoints(LOCALITY2, 0), + createEndpoints(LOCALITY3, 2))); } @Test @SuppressWarnings("unchecked") public void edsUpdatePriorityName_mergeTwoPriorities() { - verifyEdsPriorityNames(Arrays.asList(CLUSTER2 + "[child3]", CLUSTER2 + "[child1]"), - ImmutableMap.of(locality1, createEndpoints(1), - locality3, createEndpoints(3), - locality2, createEndpoints(2)), - ImmutableMap.of(locality1, createEndpoints(2), - locality3, createEndpoints(1), - locality2, createEndpoints(1) - )); + verifyEdsPriorityNames(Arrays.asList(CLUSTER + "[child3]", CLUSTER + "[child1]"), + Arrays.asList( + createEndpoints(LOCALITY1, 0), + createEndpoints(LOCALITY3, 2), + createEndpoints(LOCALITY2, 1)), + Arrays.asList( + createEndpoints(LOCALITY1, 1), + createEndpoints(LOCALITY3, 0), + createEndpoints(LOCALITY2, 0))); } - private LocalityLbEndpoints createEndpoints(int priority) { - return LocalityLbEndpoints.create( - Arrays.asList( - LbEndpoint.create(makeAddress("endpoint-addr-1"), 100, true), - LbEndpoint.create(makeAddress("endpoint-addr-2"), 100, true)), - 70 /* localityWeight */, priority /* priority */); + private 
LocalityLbEndpoints createEndpoints(Locality locality, int priority) { + return LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(70)) + .setLocality(locality) + .setPriority(priority) + .addLbEndpoints(newSocketLbEndpoint("127.0." + priority + ".1", 8080)) + .build(); } @Test public void onlyEdsClusters_resourceNeverExist_returnErrorPicker() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, edsDiscoveryMechanism2), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1, EDS_SERVICE_NAME2); - assertThat(childBalancers).isEmpty(); - reset(helper); - xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME1); - verify(helper, never()).updateBalancingState( - any(ConnectivityState.class), any(SubchannelPicker.class)); // wait for CLUSTER2's results + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); + startXdsDepManager(); - xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME2); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker( - pickerCaptor.getValue(), - Status.UNAVAILABLE.withDescription( - "No usable endpoint from cluster(s): " + Arrays.asList(CLUSTER1, CLUSTER2)), - null); + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + ": NOT_FOUND. 
" + + "Details: Timed out waiting for resource " + CLUSTER + " from xDS server nodeID: node-id"; + Status expectedError = Status.UNAVAILABLE.withDescription(expectedDescription); + assertPicker(pickerCaptor.getValue(), expectedError, null); } @Test - public void onlyEdsClusters_allResourcesRevoked_shutDownChildLbPolicy() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, edsDiscoveryMechanism2), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1, EDS_SERVICE_NAME2); - assertThat(childBalancers).isEmpty(); - reset(helper); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - LocalityLbEndpoints localityLbEndpoints1 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint1, 100, true)), - 10 /* localityWeight */, 1 /* priority */); - LocalityLbEndpoints localityLbEndpoints2 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint2, 100, true)), - 20 /* localityWeight */, 2 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints1)); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME2, Collections.singletonMap(locality2, localityLbEndpoints2)); + public void cdsMissing_handledDirectly() { + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8000))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + + startXdsDepManager(); + 
assertThat(childBalancers).hasSize(0); // no child LB policy created + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + ": NOT_FOUND. " + + "Details: Timed out waiting for resource " + CLUSTER + " from xDS server nodeID: node-id"; + Status expectedError = Status.UNAVAILABLE.withDescription(expectedDescription); + assertPicker(pickerCaptor.getValue(), expectedError, null); + assertPicker(pickerCaptor.getValue(), expectedError, null); + } + + @Test + public void cdsRevoked_handledDirectly() { + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8000))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + + startXdsDepManager(); assertThat(childBalancers).hasSize(1); // child LB policy created FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(((PriorityLbConfig) childBalancer.config).priorities).hasSize(2); - assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), childBalancer.addresses); + assertThat(((PriorityLbConfig) childBalancer.config).priorities).hasSize(1); + assertThat(childBalancer.addresses).hasSize(1); + assertAddressesEqual( + Arrays.asList(newInetSocketAddressEag("127.0.0.1", 8000)), + childBalancer.addresses); + + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + ": NOT_FOUND. 
" + + "Details: Resource " + CLUSTER + " does not exist nodeID: node-id"; + Status expectedError = Status.UNAVAILABLE.withDescription(expectedDescription); + assertPicker(pickerCaptor.getValue(), expectedError, null); + assertThat(childBalancer.shutdown).isTrue(); + } + + @Test + public void edsMissing_failsRpcs() { + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of()); - xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME2); - xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME1); + startXdsDepManager(); + assertThat(childBalancers).hasSize(0); // Graceful switch handles it, so no child policies yet verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status expectedError = Status.UNAVAILABLE.withDescription( - "No usable endpoint from cluster(s): " + Arrays.asList(CLUSTER1, CLUSTER2)); + String expectedDescription = "Error retrieving EDS resource " + EDS_SERVICE_NAME + + ": NOT_FOUND. Details: Timed out waiting for resource " + EDS_SERVICE_NAME + + " from xDS server nodeID: node-id"; + Status expectedError = Status.UNAVAILABLE.withDescription(expectedDescription); assertPicker(pickerCaptor.getValue(), expectedError, null); } + @Test + public void logicalDnsLookupFailed_failsRpcs() { + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, LOGICAL_DNS_CLUSTER)); + startXdsDepManager(new CdsConfig(CLUSTER), /* forwardTime= */ false); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME + ":9000"); + assertThat(childBalancers).isEmpty(); + Status status = Status.UNAVAILABLE.withDescription("OH NO! 
Who would have guessed?"); + resolver.deliverError(status); + + assertThat(childBalancers).hasSize(0); // Graceful switch handles it, so no child policies yet + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + assertPicker(pickerCaptor.getValue(), status, null); + } + @Test public void handleEdsResource_ignoreUnhealthyEndpoints() { - ClusterResolverConfig config = - new ClusterResolverConfig(Collections.singletonList(edsDiscoveryMechanism1), roundRobin); - deliverLbConfig(config); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Arrays.asList( - LbEndpoint.create(endpoint1, 100, false /* isHealthy */), - LbEndpoint.create(endpoint2, 100, true /* isHealthy */)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8000) + .setHealthStatus(HealthStatus.UNHEALTHY)) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.2", 8000))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.addresses).hasSize(1); - assertAddressesEqual(Collections.singletonList(endpoint2), childBalancer.addresses); + assertAddressesEqual( + Arrays.asList(new 
EquivalentAddressGroup(newInetSocketAddress("127.0.0.2", 8000))), + childBalancer.addresses); } @Test public void handleEdsResource_ignoreLocalitiesWithNoHealthyEndpoints() { - ClusterResolverConfig config = - new ClusterResolverConfig(Collections.singletonList(edsDiscoveryMechanism1), roundRobin); - deliverLbConfig(config); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - LocalityLbEndpoints localityLbEndpoints1 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint1, 100, false /* isHealthy */)), - 10 /* localityWeight */, 1 /* priority */); - LocalityLbEndpoints localityLbEndpoints2 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint2, 100, true /* isHealthy */)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, - ImmutableMap.of(locality1, localityLbEndpoints1, locality2, localityLbEndpoints2)); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8000) + .setHealthStatus(HealthStatus.UNHEALTHY))) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY2) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.2", 8000))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); + io.grpc.xds.client.Locality locality2 = io.grpc.xds.client.Locality.create( + LOCALITY2.getRegion(), LOCALITY2.getZone(), 
LOCALITY2.getSubZone()); for (EquivalentAddressGroup eag : childBalancer.addresses) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY)).isEqualTo(locality2); + assertThat(eag.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY)) + .isEqualTo(locality2); } } @Test public void handleEdsResource_ignorePrioritiesWithNoHealthyEndpoints() { - ClusterResolverConfig config = - new ClusterResolverConfig(Collections.singletonList(edsDiscoveryMechanism1), roundRobin); - deliverLbConfig(config); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - LocalityLbEndpoints localityLbEndpoints1 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint1, 100, false /* isHealthy */)), - 10 /* localityWeight */, 1 /* priority */); - LocalityLbEndpoints localityLbEndpoints2 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint2, 200, true /* isHealthy */)), - 10 /* localityWeight */, 2 /* priority */); - String priority2 = CLUSTER1 + "[child2]"; - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, - ImmutableMap.of(locality1, localityLbEndpoints1, locality2, localityLbEndpoints2)); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .setPriority(0) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8000) + .setHealthStatus(HealthStatus.UNHEALTHY))) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY2) + .setPriority(1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.2", 8000))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + verify(helper, 
never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + String priority2 = CLUSTER + "[child2]"; FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(((PriorityLbConfig) childBalancer.config).priorities).containsExactly(priority2); } @Test public void handleEdsResource_noHealthyEndpoint() { - ClusterResolverConfig config = - new ClusterResolverConfig(Collections.singletonList(edsDiscoveryMechanism1), roundRobin); - deliverLbConfig(config); - EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint, 100, false /* isHealthy */)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment(EDS_SERVICE_NAME1, - Collections.singletonMap(locality1, localityLbEndpoints)); // single endpoint, unhealthy + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8000) + .setHealthStatus(HealthStatus.UNHEALTHY))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); - assertThat(childBalancers).isEmpty(); + assertThat(childBalancers).hasSize(0); // Graceful switch handles it, so no child policies yet verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker( - pickerCaptor.getValue(), - Status.UNAVAILABLE.withDescription( - "No usable endpoint from cluster(s): " + Collections.singleton(CLUSTER1)), - null); + Status expectedStatus = Status.UNAVAILABLE + .withDescription("No usable endpoint from cluster: " + CLUSTER); + assertPicker(pickerCaptor.getValue(), expectedStatus, 
null); } @Test public void onlyLogicalDnsCluster_endpointsResolved() { - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); + boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; + LoadStatsManager2.isEnabledOrcaLrsPropagation = true; + List metricSpecs = Arrays.asList("cpu_utilization"); + BackendMetricPropagation backendMetricPropagation = + BackendMetricPropagation.fromMetricSpecs(metricSpecs); + Cluster logicalDnsClusterWithMetrics = LOGICAL_DNS_CLUSTER.toBuilder() + .addAllLrsReportEndpointMetrics(metricSpecs) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, logicalDnsClusterWithMetrics)); + startXdsDepManager(new CdsConfig(CLUSTER), /* forwardTime= */ false); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME + ":9000"); assertThat(childBalancers).isEmpty(); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); + resolver.deliverEndpointAddresses(Arrays.asList( + newInetSocketAddressEag("127.0.2.1", 9000), newInetSocketAddressEag("127.0.2.2", 9000))); + fakeClock.forwardTime(10, TimeUnit.MINUTES); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); @@ -786,341 +908,142 @@ public void onlyLogicalDnsCluster_endpointsResolved() { PriorityChildConfig priorityChildConfig = priorityLbConfig.childConfigs.get(priority); assertThat(priorityChildConfig.ignoreReresolution).isFalse(); 
assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider(priorityChildConfig.childConfig) - .getPolicyName()) + .getPolicyName()) .isEqualTo(CLUSTER_IMPL_POLICY_NAME); ClusterImplConfig clusterImplConfig = (ClusterImplConfig) GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig.childConfig); - assertClusterImplConfig(clusterImplConfig, CLUSTER_DNS, null, LRS_SERVER_INFO, 300L, null, - Collections.emptyList(), "pick_first"); - assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), childBalancer.addresses); + assertClusterImplConfig(clusterImplConfig, CLUSTER, null, null, null, null, + Collections.emptyList(), "wrr_locality_experimental"); + assertThat(clusterImplConfig.backendMetricPropagation).isEqualTo(backendMetricPropagation); + LoadStatsManager2.isEnabledOrcaLrsPropagation = originalVal; + assertAddressesEqual( + Arrays.asList(new EquivalentAddressGroup(Arrays.asList( + newInetSocketAddress("127.0.2.1", 9000), newInetSocketAddress("127.0.2.2", 9000)))), + childBalancer.addresses); + assertThat(childBalancer.addresses.get(0).getAttributes() + .get(XdsInternalAttributes.ATTR_ADDRESS_NAME)).isEqualTo(DNS_HOST_NAME + ":9000"); } @Test public void onlyLogicalDnsCluster_handleRefreshNameResolution() { - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, LOGICAL_DNS_CLUSTER)); + startXdsDepManager(new CdsConfig(CLUSTER), /* forwardTime= */ false); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME + ":9000"); assertThat(childBalancers).isEmpty(); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); - 
assertThat(resolver.refreshCount).isEqualTo(0); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - childBalancer.helper.refreshNameResolution(); - assertThat(resolver.refreshCount).isEqualTo(1); - } + resolver.deliverEndpointAddresses(Arrays.asList(newInetSocketAddressEag("127.0.2.1", 9000))); + fakeClock.forwardTime(10, TimeUnit.MINUTES); - @Test - public void onlyLogicalDnsCluster_resolutionError_backoffAndRefresh() { - InOrder inOrder = Mockito.inOrder(helper, backoffPolicyProvider, - backoffPolicy1, backoffPolicy2); - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - Status error = Status.UNAVAILABLE.withDescription("cannot reach DNS server"); - resolver.deliverError(error); - inOrder.verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), error, null); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(resolver.refreshCount).isEqualTo(0); - inOrder.verify(backoffPolicyProvider).get(); - inOrder.verify(backoffPolicy1).nextBackoffNanos(); - assertThat(fakeClock.getPendingTasks()).hasSize(1); - assertThat(Iterables.getOnlyElement(fakeClock.getPendingTasks()).getDelay(TimeUnit.SECONDS)) - .isEqualTo(1L); - fakeClock.forwardTime(1L, TimeUnit.SECONDS); - assertThat(resolver.refreshCount).isEqualTo(1); - - error = Status.UNKNOWN.withDescription("I am lost"); - resolver.deliverError(error); - inOrder.verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - inOrder.verify(backoffPolicy1).nextBackoffNanos(); - assertPicker(pickerCaptor.getValue(), error, null); - assertThat(fakeClock.getPendingTasks()).hasSize(1); - 
assertThat(Iterables.getOnlyElement(fakeClock.getPendingTasks()).getDelay(TimeUnit.SECONDS)) - .isEqualTo(10L); - fakeClock.forwardTime(10L, TimeUnit.SECONDS); - assertThat(resolver.refreshCount).isEqualTo(2); - - // Succeed. - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); - assertThat(childBalancers).hasSize(1); - assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), - Iterables.getOnlyElement(childBalancers).addresses); - - assertThat(fakeClock.getPendingTasks()).isEmpty(); - inOrder.verifyNoMoreInteractions(); - } - - @Test - public void onlyLogicalDnsCluster_refreshNameResolutionRaceWithResolutionError() { - InOrder inOrder = Mockito.inOrder(backoffPolicyProvider, backoffPolicy1, backoffPolicy2); - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - EquivalentAddressGroup endpoint = makeAddress("endpoint-addr"); - resolver.deliverEndpointAddresses(Collections.singletonList(endpoint)); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertAddressesEqual(Collections.singletonList(endpoint), childBalancer.addresses); - assertThat(resolver.refreshCount).isEqualTo(0); - childBalancer.helper.refreshNameResolution(); assertThat(resolver.refreshCount).isEqualTo(1); - resolver.deliverError(Status.UNAVAILABLE.withDescription("I am lost")); - inOrder.verify(backoffPolicyProvider).get(); - inOrder.verify(backoffPolicy1).nextBackoffNanos(); - assertThat(fakeClock.getPendingTasks()).hasSize(1); - ScheduledTask task = Iterables.getOnlyElement(fakeClock.getPendingTasks()); - assertThat(task.getDelay(TimeUnit.SECONDS)).isEqualTo(1L); - - fakeClock.forwardTime( 100L, 
TimeUnit.MILLISECONDS); - childBalancer.helper.refreshNameResolution(); - assertThat(resolver.refreshCount).isEqualTo(2); - assertThat(task.isCancelled()).isTrue(); - assertThat(fakeClock.getPendingTasks()).isEmpty(); - resolver.deliverError(Status.UNAVAILABLE.withDescription("I am still lost")); - inOrder.verify(backoffPolicyProvider).get(); // active refresh resets backoff sequence - inOrder.verify(backoffPolicy2).nextBackoffNanos(); - task = Iterables.getOnlyElement(fakeClock.getPendingTasks()); - assertThat(task.getDelay(TimeUnit.SECONDS)).isEqualTo(5L); - - fakeClock.forwardTime(5L, TimeUnit.SECONDS); - assertThat(resolver.refreshCount).isEqualTo(3); - inOrder.verifyNoMoreInteractions(); } @Test - public void edsClustersAndLogicalDnsCluster_receivedEndpoints() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); // DNS endpoint - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); // DNS endpoint - EquivalentAddressGroup endpoint3 = makeAddress("endpoint-addr-3"); // EDS endpoint - resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint3, 100, true)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); + public void outlierDetection_disabledConfig() { + Cluster cluster = EDS_CLUSTER.toBuilder() + .setOutlierDetection(OutlierDetection.newBuilder() + .setEnforcingSuccessRate(UInt32Value.of(0)) + 
.setEnforcingFailurePercentage(UInt32Value.of(0))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, cluster)); + startXdsDepManager(); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(((PriorityLbConfig) childBalancer.config).priorities) - .containsExactly(CLUSTER1 + "[child1]", CLUSTER_DNS + "[child0]").inOrder(); - assertAddressesEqual(Arrays.asList(endpoint3, endpoint1, endpoint2), - childBalancer.addresses); // ordered by cluster then addresses - assertAddressesEqual(AddressFilter.filter(AddressFilter.filter( - childBalancer.addresses, CLUSTER1 + "[child1]"), - "{region=\"test-region-1\", zone=\"test-zone-1\", sub_zone=\"test-subzone-1\"}"), - Collections.singletonList(endpoint3)); - assertAddressesEqual(AddressFilter.filter(AddressFilter.filter( - childBalancer.addresses, CLUSTER_DNS + "[child0]"), - "{region=\"\", zone=\"\", sub_zone=\"\"}"), - Arrays.asList(endpoint1, endpoint2)); + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); + PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; + PriorityChildConfig priorityChildConfig = + Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); + OutlierDetectionLoadBalancerConfig outlier = (OutlierDetectionLoadBalancerConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig.childConfig); + assertThat(outlier.successRateEjection).isNull(); + assertThat(outlier.failurePercentageEjection).isNull(); } @Test - public void noEdsResourceExists_useDnsResolutionResults() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - FakeNameResolver 
resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - reset(helper); - xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME1); - verify(helper, never()).updateBalancingState( - any(ConnectivityState.class), any(SubchannelPicker.class)); // wait for DNS results - - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); - assertThat(childBalancers).hasSize(1); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - String priority = Iterables.getOnlyElement( - ((PriorityLbConfig) childBalancer.config).priorities); - assertThat(priority).isEqualTo(CLUSTER_DNS + "[child0]"); - assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), childBalancer.addresses); - } + public void outlierDetection_fullConfig() { + Cluster cluster = EDS_CLUSTER.toBuilder() + .setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN) + .setOutlierDetection(OutlierDetection.newBuilder() + .setInterval(Duration.newBuilder().setNanos(101)) + .setBaseEjectionTime(Duration.newBuilder().setNanos(102)) + .setMaxEjectionTime(Duration.newBuilder().setNanos(103)) + .setMaxEjectionPercent(UInt32Value.of(80)) + .setSuccessRateStdevFactor(UInt32Value.of(105)) + .setEnforcingSuccessRate(UInt32Value.of(81)) + .setSuccessRateMinimumHosts(UInt32Value.of(107)) + .setSuccessRateRequestVolume(UInt32Value.of(108)) + .setFailurePercentageThreshold(UInt32Value.of(82)) + .setEnforcingFailurePercentage(UInt32Value.of(83)) + .setFailurePercentageMinimumHosts(UInt32Value.of(111)) + .setFailurePercentageRequestVolume(UInt32Value.of(112))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, cluster)); + startXdsDepManager(); - @Test - public void edsResourceRevoked_dnsResolutionError_shutDownChildLbPolicyAndReturnErrorPicker() { - ClusterResolverConfig config = new 
ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - reset(helper); - EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint, 100, true)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); - resolver.deliverError(Status.UNKNOWN.withDescription("I am lost")); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(((PriorityLbConfig) childBalancer.config).priorities) - .containsExactly(CLUSTER1 + "[child1]"); - assertAddressesEqual(Collections.singletonList(endpoint), childBalancer.addresses); - assertThat(childBalancer.shutdown).isFalse(); - xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME1); - assertThat(childBalancer.shutdown).isTrue(); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), - Status.UNAVAILABLE.withDescription("I am lost"), null); - } - - @Test - public void resolutionErrorAfterChildLbCreated_propagateErrorIfAllClustersEncounterError() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - 
assertThat(childBalancers).isEmpty(); - reset(helper); - EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint, 100, true)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); - assertThat(childBalancers).isEmpty(); // not created until all clusters resolved. - - resolver.deliverError(Status.UNKNOWN.withDescription("I am lost")); - - // DNS resolution failed, but there are EDS endpoints can be used. - assertThat(childBalancers).hasSize(1); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); // child LB created - assertThat(childBalancer.upstreamError).isNull(); // should not propagate error to child LB - assertAddressesEqual(Collections.singletonList(endpoint), childBalancer.addresses); - - xdsClient.deliverError(Status.RESOURCE_EXHAUSTED.withDescription("out of memory")); - assertThat(childBalancer.upstreamError).isNotNull(); // last cluster's (DNS) error propagated - assertThat(childBalancer.upstreamError.getCode()).isEqualTo(Code.UNKNOWN); - assertThat(childBalancer.upstreamError.getDescription()).isEqualTo("I am lost"); - assertThat(childBalancer.shutdown).isFalse(); - verify(helper, never()).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), any(SubchannelPicker.class)); - } - - @Test - public void resolutionErrorBeforeChildLbCreated_returnErrorPickerIfAllClustersEncounterError() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - reset(helper); - 
xdsClient.deliverError(Status.UNIMPLEMENTED.withDescription("not found")); - assertThat(childBalancers).isEmpty(); - verify(helper, never()).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), any(SubchannelPicker.class)); // wait for DNS - Status dnsError = Status.UNKNOWN.withDescription("I am lost"); - resolver.deliverError(dnsError); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker( - pickerCaptor.getValue(), - Status.UNAVAILABLE.withDescription(dnsError.getDescription()), - null); - } - - @Test - public void resolutionErrorBeforeChildLbCreated_edsOnly_returnErrorPicker() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(childBalancers).isEmpty(); - reset(helper); - xdsClient.deliverError(Status.RESOURCE_EXHAUSTED.withDescription("OOM")); - assertThat(childBalancers).isEmpty(); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); - Status actualStatus = result.getStatus(); - assertThat(actualStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); - assertThat(actualStatus.getDescription()).contains("RESOURCE_EXHAUSTED: OOM"); + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); + PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; + PriorityChildConfig priorityChildConfig = + Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); + OutlierDetectionLoadBalancerConfig outlier = (OutlierDetectionLoadBalancerConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig.childConfig); + assertThat(outlier.intervalNanos).isEqualTo(101); + assertThat(outlier.baseEjectionTimeNanos).isEqualTo(102); 
+ assertThat(outlier.maxEjectionTimeNanos).isEqualTo(103); + assertThat(outlier.maxEjectionPercent).isEqualTo(80); + assertThat(outlier.successRateEjection.stdevFactor).isEqualTo(105); + assertThat(outlier.successRateEjection.enforcementPercentage).isEqualTo(81); + assertThat(outlier.successRateEjection.minimumHosts).isEqualTo(107); + assertThat(outlier.successRateEjection.requestVolume).isEqualTo(108); + assertThat(outlier.failurePercentageEjection.threshold).isEqualTo(82); + assertThat(outlier.failurePercentageEjection.enforcementPercentage).isEqualTo(83); + assertThat(outlier.failurePercentageEjection.minimumHosts).isEqualTo(111); + assertThat(outlier.failurePercentageEjection.requestVolume).isEqualTo(112); + assertClusterImplConfig( + (ClusterImplConfig) GracefulSwitchLoadBalancerAccessor.getChildConfig(outlier.childConfig), + CLUSTER, EDS_SERVICE_NAME, null, null, null, Collections.emptyList(), + "wrr_locality_experimental"); } - @Test - public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_returnErrorPicker() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - reset(helper); - Status upstreamError = Status.UNAVAILABLE.withDescription("unreachable"); - loadBalancer.handleNameResolutionError(upstreamError); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), upstreamError, null); + private void startXdsDepManager() { + startXdsDepManager(new CdsConfig(CLUSTER)); } - @Test - public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThrough() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, 
logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - reset(helper); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint1, 100, true)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); - resolver.deliverEndpointAddresses(Collections.singletonList(endpoint2)); - assertThat(childBalancers).hasSize(1); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(((PriorityLbConfig) childBalancer.config).priorities) - .containsExactly(CLUSTER1 + "[child1]", CLUSTER_DNS + "[child0]"); - assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), childBalancer.addresses); - - loadBalancer.handleNameResolutionError(Status.UNAVAILABLE.withDescription("unreachable")); - assertThat(childBalancer.upstreamError.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(childBalancer.upstreamError.getDescription()).isEqualTo("unreachable"); - verify(helper, never()).updateBalancingState( - any(ConnectivityState.class), any(SubchannelPicker.class)); + private void startXdsDepManager(final CdsConfig cdsConfig) { + startXdsDepManager(cdsConfig, true); } - private void deliverLbConfig(ClusterResolverConfig config) { - loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(Collections.emptyList()) - .setAttributes( - // Other attributes not used by cluster_resolver LB are omitted. 
- Attributes.newBuilder() - .set(InternalXdsAttributes.XDS_CLIENT_POOL, xdsClientPool) - .build()) - .setLoadBalancingPolicyConfig(config) - .build()); + private void startXdsDepManager(final CdsConfig cdsConfig, boolean forwardTime) { + xdsDepManager.start( + xdsConfig -> { + if (!xdsConfig.hasValue()) { + throw new AssertionError("" + xdsConfig.getStatus()); + } + if (loadBalancer == null) { + return; + } + loadBalancer.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(Collections.emptyList()) + .setAttributes(Attributes.newBuilder() + .set(io.grpc.xds.XdsAttributes.XDS_CONFIG, xdsConfig.getValue()) + .set(io.grpc.xds.XdsAttributes.XDS_CLUSTER_SUBSCRIPT_REGISTRY, xdsDepManager) + .build()) + .setLoadBalancingPolicyConfig(cdsConfig) + .build()); + }); + if (forwardTime) { + // trigger does not exist timer, so broken config is more obvious + fakeClock.forwardTime(10, TimeUnit.MINUTES); + } } private FakeNameResolver assertResolverCreated(String uriPath) { @@ -1159,96 +1082,34 @@ private static void assertClusterImplConfig(ClusterImplConfig config, String clu /** Asserts two list of EAGs contains same addresses, regardless of attributes. 
*/ private static void assertAddressesEqual( List expected, List actual) { - List> expectedAddresses + List> expectedAddresses = expected.stream().map(EquivalentAddressGroup::getAddresses).collect(toList()); - List> actualAddresses + List> actualAddresses = actual.stream().map(EquivalentAddressGroup::getAddresses).collect(toList()); assertThat(actualAddresses).isEqualTo(expectedAddresses); } - private static EquivalentAddressGroup makeAddress(final String name) { - class FakeSocketAddress extends SocketAddress { - private final String name; - - private FakeSocketAddress(String name) { - this.name = name; - } - - @Override - public int hashCode() { - return Objects.hash(name); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof FakeSocketAddress)) { - return false; - } - FakeSocketAddress that = (FakeSocketAddress) o; - return Objects.equals(name, that.name); - } - - @Override - public String toString() { - return name; - } - } - - return new EquivalentAddressGroup(new FakeSocketAddress(name)); + @SuppressWarnings("AddressSelection") + private static InetSocketAddress newInetSocketAddress(String ip, int port) { + return new InetSocketAddress(ip, port); } - private static final class FakeXdsClient extends XdsClient { - private final Map> watchers = new HashMap<>(); - - @Override - @SuppressWarnings("unchecked") - public void watchXdsResource(XdsResourceType type, - String resourceName, - ResourceWatcher watcher, - Executor syncContext) { - assertThat(type.typeName()).isEqualTo("EDS"); - assertThat(watchers).doesNotContainKey(resourceName); - watchers.put(resourceName, (ResourceWatcher) watcher); - } - - @Override - @SuppressWarnings("unchecked") - public void cancelXdsResourceWatch(XdsResourceType type, - String resourceName, - ResourceWatcher watcher) { - assertThat(type.typeName()).isEqualTo("EDS"); - assertThat(watchers).containsKey(resourceName); - watchers.remove(resourceName); - } - - void 
deliverClusterLoadAssignment( - String resource, Map localityLbEndpointsMap) { - deliverClusterLoadAssignment( - resource, Collections.emptyList(), localityLbEndpointsMap); - } - - void deliverClusterLoadAssignment(String resource, List dropOverloads, - Map localityLbEndpointsMap) { - if (watchers.containsKey(resource)) { - watchers.get(resource).onChanged( - new XdsEndpointResource.EdsUpdate(resource, localityLbEndpointsMap, dropOverloads)); - } - } + private static EquivalentAddressGroup newInetSocketAddressEag(String ip, int port) { + return new EquivalentAddressGroup(newInetSocketAddress(ip, port)); + } - void deliverResourceNotFound(String resource) { - if (watchers.containsKey(resource)) { - watchers.get(resource).onResourceDoesNotExist(resource); - } - } + private static LbEndpoint.Builder newSocketLbEndpoint(String ip, int port) { + return LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(newAddress(ip, port))) + .setHealthStatus(HealthStatus.HEALTHY); + } - void deliverError(Status error) { - for (ResourceWatcher watcher : watchers.values()) { - watcher.onError(error); - } - } + private static Address.Builder newAddress(String ip, int port) { + return Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress(ip) + .setPortValue(port)); } private class FakeNameResolverProvider extends NameResolverProvider { @@ -1276,9 +1137,10 @@ protected int priority() { } } + private class FakeNameResolver extends NameResolver { private final URI targetUri; - private Listener2 listener; + protected Listener2 listener; private int refreshCount; private FakeNameResolver(URI targetUri) { @@ -1305,12 +1167,17 @@ public void shutdown() { resolvers.remove(this); } - private void deliverEndpointAddresses(List addresses) { - listener.onResult(ResolutionResult.newBuilder().setAddresses(addresses).build()); + protected void deliverEndpointAddresses(List addresses) { + syncContext.execute(() -> { + Status ret = 
listener.onResult2(ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(addresses)).build()); + assertThat(ret.getCode()).isEqualTo(Status.Code.OK); + }); } - private void deliverError(Status error) { - listener.onError(error); + protected void deliverError(Status error) { + syncContext.execute(() -> listener.onResult2(ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromStatus(error)).build())); } } @@ -1349,7 +1216,6 @@ private final class FakeLoadBalancer extends LoadBalancer { private final Helper helper; private List addresses; private Object config; - private Status upstreamError; private boolean shutdown; FakeLoadBalancer(String name, Helper helper) { @@ -1366,7 +1232,6 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { @Override public void handleNameResolutionError(Status error) { - upstreamError = error; } @Override diff --git a/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java b/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java index 1ddf9620434..3665e16b6bf 100644 --- a/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java +++ b/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java @@ -22,7 +22,9 @@ import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; import com.google.protobuf.Message; import com.google.protobuf.UInt32Value; import io.envoyproxy.envoy.config.cluster.v3.Cluster; @@ -55,6 +57,7 @@ import io.grpc.InsecureServerCredentials; import io.grpc.NameResolverRegistry; import io.grpc.Server; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.UUID; @@ -86,9 +89,11 @@ public class ControlPlaneRule extends TestWatcher { private XdsTestControlPlaneService controlPlaneService; private XdsTestLoadReportingService loadReportingService; private XdsNameResolverProvider 
nameResolverProvider; + private int port; // Only change from 0 to actual port used in the server. public ControlPlaneRule() { serverHostName = "test-server"; + this.port = 0; } public ControlPlaneRule setServerHostName(String serverHostName) { @@ -115,11 +120,7 @@ public Server getServer() { try { controlPlaneService = new XdsTestControlPlaneService(); loadReportingService = new XdsTestLoadReportingService(); - server = Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) - .addService(controlPlaneService) - .addService(loadReportingService) - .build() - .start(); + createAndStartXdsServer(); } catch (Exception e) { throw new AssertionError("unable to start the control plane server", e); } @@ -144,6 +145,42 @@ public Server getServer() { NameResolverRegistry.getDefaultRegistry().deregister(nameResolverProvider); } + /** + * Will shutdown existing server if needed. + * Then creates a new server in the same way as {@link #starting(Description)} and starts it. + */ + public void restartXdsServer() { + + if (getServer() != null && !getServer().isTerminated()) { + getServer().shutdownNow(); + try { + if (!getServer().awaitTermination(5, TimeUnit.SECONDS)) { + logger.log(Level.SEVERE, "Timed out waiting for server shutdown"); + } + } catch (InterruptedException e) { + throw new AssertionError("unable to shut down control plane server", e); + } + } + + try { + createAndStartXdsServer(); + } catch (Exception e) { + throw new AssertionError("unable to restart the control plane server", e); + } + } + + private void createAndStartXdsServer() throws IOException { + server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) + .addService(controlPlaneService) + .addService(loadReportingService) + .build() + .start(); + + if (port == 0) { + port = server.getPort(); + } + } + /** * For test purpose, use boostrapOverride to programmatically provide bootstrap info. 
*/ @@ -159,7 +196,7 @@ public Server getServer() { "channel_creds", Collections.singletonList( ImmutableMap.of("type", "insecure") ), - "server_features", Collections.singletonList("xds_v3") + "server_features", Lists.newArrayList("xds_v3", "trusted_xds_server") ) ), "server_listener_resource_name_template", SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT @@ -173,44 +210,70 @@ void setLdsConfig(Listener serverListener, Listener clientListener) { } void setRdsConfig(RouteConfiguration routeConfiguration) { - getService().setXdsConfig(ADS_TYPE_URL_RDS, ImmutableMap.of(RDS_NAME, routeConfiguration)); + setRdsConfig(RDS_NAME, routeConfiguration); + } + + public void setRdsConfig(String rdsName, RouteConfiguration routeConfiguration) { + getService().setXdsConfig(ADS_TYPE_URL_RDS, ImmutableMap.of(rdsName, routeConfiguration)); } void setCdsConfig(Cluster cluster) { + setCdsConfig(CLUSTER_NAME, cluster); + } + + void setCdsConfig(String clusterName, Cluster cluster) { getService().setXdsConfig(ADS_TYPE_URL_CDS, - ImmutableMap.of(CLUSTER_NAME, cluster)); + ImmutableMap.of(clusterName, cluster)); } void setEdsConfig(ClusterLoadAssignment clusterLoadAssignment) { + setEdsConfig(EDS_NAME, clusterLoadAssignment); + } + + void setEdsConfig(String edsName, ClusterLoadAssignment clusterLoadAssignment) { getService().setXdsConfig(ADS_TYPE_URL_EDS, - ImmutableMap.of(EDS_NAME, clusterLoadAssignment)); + ImmutableMap.of(edsName, clusterLoadAssignment)); } /** * Builds a new default RDS configuration. 
*/ static RouteConfiguration buildRouteConfiguration(String authority) { - io.envoyproxy.envoy.config.route.v3.VirtualHost virtualHost = VirtualHost.newBuilder() - .addDomains(authority) - .addRoutes( - Route.newBuilder() - .setMatch( - RouteMatch.newBuilder().setPrefix("/").build()) - .setRoute( - RouteAction.newBuilder().setCluster(CLUSTER_NAME).build()).build()).build(); - return RouteConfiguration.newBuilder().setName(RDS_NAME).addVirtualHosts(virtualHost).build(); + return buildRouteConfiguration(authority, RDS_NAME, CLUSTER_NAME); + } + + static RouteConfiguration buildRouteConfiguration(String authority, String rdsName, + String clusterName) { + io.envoyproxy.envoy.config.route.v3.VirtualHost.Builder vhBuilder = + io.envoyproxy.envoy.config.route.v3.VirtualHost.newBuilder() + .setName(rdsName) + .addDomains(authority) + .addRoutes( + Route.newBuilder() + .setMatch( + RouteMatch.newBuilder().setPrefix("/").build()) + .setRoute( + RouteAction.newBuilder().setCluster(clusterName) + .setAutoHostRewrite(BoolValue.newBuilder().setValue(true).build()) + .build())); + io.envoyproxy.envoy.config.route.v3.VirtualHost virtualHost = vhBuilder.build(); + return RouteConfiguration.newBuilder().setName(rdsName).addVirtualHosts(virtualHost).build(); } /** * Builds a new default CDS configuration. */ static Cluster buildCluster() { + return buildCluster(CLUSTER_NAME, EDS_NAME); + } + + static Cluster buildCluster(String clusterName, String edsName) { return Cluster.newBuilder() - .setName(CLUSTER_NAME) + .setName(clusterName) .setType(Cluster.DiscoveryType.EDS) .setEdsClusterConfig( Cluster.EdsClusterConfig.newBuilder() - .setServiceName(EDS_NAME) + .setServiceName(edsName) .setEdsConfig( ConfigSource.newBuilder() .setAds(AggregatedConfigSource.newBuilder().build()) @@ -223,21 +286,29 @@ static Cluster buildCluster() { /** * Builds a new default EDS configuration. 
*/ - static ClusterLoadAssignment buildClusterLoadAssignment(String hostName, int port) { + static ClusterLoadAssignment buildClusterLoadAssignment( + String hostAddress, String endpointHostname, int port) { + return buildClusterLoadAssignment(hostAddress, endpointHostname, port, EDS_NAME); + } + + static ClusterLoadAssignment buildClusterLoadAssignment( + String hostAddress, String endpointHostname, int port, String edsName) { + Address address = Address.newBuilder() .setSocketAddress( - SocketAddress.newBuilder().setAddress(hostName).setPortValue(port).build()).build(); + SocketAddress.newBuilder().setAddress(hostAddress).setPortValue(port).build()).build(); LocalityLbEndpoints endpoints = LocalityLbEndpoints.newBuilder() .setLoadBalancingWeight(UInt32Value.of(10)) .setPriority(0) .addLbEndpoints( LbEndpoint.newBuilder() .setEndpoint( - Endpoint.newBuilder().setAddress(address).build()) + Endpoint.newBuilder() + .setAddress(address).setHostname(endpointHostname).build()) .setHealthStatus(HealthStatus.HEALTHY) .build()).build(); return ClusterLoadAssignment.newBuilder() - .setClusterName(EDS_NAME) + .setClusterName(edsName) .addEndpoints(endpoints) .build(); } @@ -246,6 +317,10 @@ static ClusterLoadAssignment buildClusterLoadAssignment(String hostName, int por * Builds a new client listener. 
*/ static Listener buildClientListener(String name) { + return buildClientListener(name, RDS_NAME); + } + + static Listener buildClientListener(String name, String rdsName) { HttpFilter httpFilter = HttpFilter.newBuilder() .setName("terminal-filter") .setTypedConfig(Any.pack(Router.newBuilder().build())) @@ -256,7 +331,7 @@ static Listener buildClientListener(String name) { .HttpConnectionManager.newBuilder() .setRds( Rds.newBuilder() - .setRouteConfigName(RDS_NAME) + .setRouteConfigName(rdsName) .setConfigSource( ConfigSource.newBuilder() .setAds(AggregatedConfigSource.getDefaultInstance()))) @@ -306,10 +381,14 @@ static Listener buildServerListener() { .setFilterChainMatch(filterChainMatch) .addFilters(filter) .build(); + Address address = Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder().setAddress("0.0.0.0").setPortValue(0)) + .build(); return Listener.newBuilder() .setName(SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT) .setTrafficDirection(TrafficDirection.INBOUND) .addFilterChains(filterChain) + .setAddress(address) .build(); } } diff --git a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java index 63b9cda043c..e8bd7461736 100644 --- a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java +++ b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java @@ -39,6 +39,7 @@ import io.envoyproxy.envoy.type.matcher.v3.NodeMatcher; import io.grpc.Deadline; import io.grpc.InsecureChannelCredentials; +import io.grpc.MetricRecorder; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; @@ -107,7 +108,7 @@ public void setUp() { // because true->false return mutation prevents fetchClientStatus from completing the request. 
csdsStub = ClientStatusDiscoveryServiceGrpc .newBlockingStub(grpcServerRule.getChannel()) - .withDeadline(Deadline.after(3, TimeUnit.SECONDS)); + .withDeadline(Deadline.after(30, TimeUnit.SECONDS)); csdsAsyncStub = ClientStatusDiscoveryServiceGrpc.newStub(grpcServerRule.getChannel()); } @@ -298,7 +299,7 @@ private void verifyResponse(ClientStatusResponse response) { assertThat(response.getConfigCount()).isEqualTo(1); ClientConfig clientConfig = response.getConfig(0); verifyClientConfigNode(clientConfig); - verifyClientConfigNoResources(XDS_CLIENT_NO_RESOURCES, clientConfig); + assertThat(clientConfig.getGenericXdsConfigsList()).isEmpty(); assertThat(clientConfig.getClientScope()).isEmpty(); } @@ -309,7 +310,7 @@ private Collection verifyMultiResponse(ClientStatusResponse response, in for (int i = 0; i < numExpected; i++) { ClientConfig clientConfig = response.getConfig(i); verifyClientConfigNode(clientConfig); - verifyClientConfigNoResources(XDS_CLIENT_NO_RESOURCES, clientConfig); + assertThat(clientConfig.getGenericXdsConfigsList()).isEmpty(); clientScopes.add(clientConfig.getClientScope()); } @@ -365,6 +366,8 @@ public void metadataStatusToClientStatus() { .isEqualTo(ClientResourceStatus.ACKED); assertThat(CsdsService.metadataStatusToClientStatus(ResourceMetadataStatus.NACKED)) .isEqualTo(ClientResourceStatus.NACKED); + assertThat(CsdsService.metadataStatusToClientStatus(ResourceMetadataStatus.TIMEOUT)) + .isEqualTo(ClientResourceStatus.TIMEOUT); } @Test @@ -381,16 +384,6 @@ public void getClientConfigForXdsClient_subscribedResourcesToGenericXdsConfig() .put(EDS, ImmutableMap.of("subscribedResourceName.EDS", METADATA_ACKED_EDS)) .buildOrThrow(); } - - @Override - public Map> getSubscribedResourceTypesWithTypeUrl() { - return ImmutableMap.of( - LDS.typeUrl(), LDS, - RDS.typeUrl(), RDS, - CDS.typeUrl(), CDS, - EDS.typeUrl(), EDS - ); - } }; ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(fakeXdsClient, FAKE_CLIENT_SCOPE); @@ -402,31 +395,31 @@ 
public Map> getSubscribedResourceTypesWithTypeUrl() { // is propagated to the correct resource types. int xdsConfigCount = clientConfig.getGenericXdsConfigsCount(); assertThat(xdsConfigCount).isEqualTo(4); - Map, GenericXdsConfig> configDumps = mapConfigDumps(fakeXdsClient, - clientConfig); - assertThat(configDumps.keySet()).containsExactly(LDS, RDS, CDS, EDS); + Map configDumps = mapConfigDumps(clientConfig); + assertThat(configDumps.keySet()) + .containsExactly(LDS.typeUrl(), RDS.typeUrl(), CDS.typeUrl(), EDS.typeUrl()); // LDS. - GenericXdsConfig genericXdsConfigLds = configDumps.get(LDS); + GenericXdsConfig genericXdsConfigLds = configDumps.get(LDS.typeUrl()); assertThat(genericXdsConfigLds.getName()).isEqualTo("subscribedResourceName.LDS"); assertThat(genericXdsConfigLds.getClientStatus()).isEqualTo(ClientResourceStatus.ACKED); assertThat(genericXdsConfigLds.getVersionInfo()).isEqualTo(VERSION_ACK_LDS); assertThat(genericXdsConfigLds.getXdsConfig()).isEqualTo(RAW_LISTENER); // RDS. - GenericXdsConfig genericXdsConfigRds = configDumps.get(RDS); + GenericXdsConfig genericXdsConfigRds = configDumps.get(RDS.typeUrl()); assertThat(genericXdsConfigRds.getClientStatus()).isEqualTo(ClientResourceStatus.ACKED); assertThat(genericXdsConfigRds.getVersionInfo()).isEqualTo(VERSION_ACK_RDS); assertThat(genericXdsConfigRds.getXdsConfig()).isEqualTo(RAW_ROUTE_CONFIGURATION); // CDS. - GenericXdsConfig genericXdsConfigCds = configDumps.get(CDS); + GenericXdsConfig genericXdsConfigCds = configDumps.get(CDS.typeUrl()); assertThat(genericXdsConfigCds.getClientStatus()).isEqualTo(ClientResourceStatus.ACKED); assertThat(genericXdsConfigCds.getVersionInfo()).isEqualTo(VERSION_ACK_CDS); assertThat(genericXdsConfigCds.getXdsConfig()).isEqualTo(RAW_CLUSTER); // RDS. 
- GenericXdsConfig genericXdsConfigEds = configDumps.get(EDS); + GenericXdsConfig genericXdsConfigEds = configDumps.get(EDS.typeUrl()); assertThat(genericXdsConfigEds.getClientStatus()).isEqualTo(ClientResourceStatus.ACKED); assertThat(genericXdsConfigEds.getVersionInfo()).isEqualTo(VERSION_ACK_EDS); assertThat(genericXdsConfigEds.getXdsConfig()).isEqualTo(RAW_CLUSTER_LOAD_ASSIGNMENT); @@ -437,23 +430,11 @@ public void getClientConfigForXdsClient_noSubscribedResources() throws Interrupt ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(XDS_CLIENT_NO_RESOURCES, FAKE_CLIENT_SCOPE); verifyClientConfigNode(clientConfig); - verifyClientConfigNoResources(XDS_CLIENT_NO_RESOURCES, clientConfig); + assertThat(clientConfig.getGenericXdsConfigsList()).isEmpty(); assertThat(clientConfig.getClientScope()).isEqualTo(FAKE_CLIENT_SCOPE); } } - /** - * Assuming {@link MetadataToProtoTests} passes, and metadata converted to corresponding - * config dumps correctly, perform a minimal verification of the general shape of ClientConfig. - */ - private static void verifyClientConfigNoResources(FakeXdsClient xdsClient, - ClientConfig clientConfig) { - int xdsConfigCount = clientConfig.getGenericXdsConfigsCount(); - assertThat(xdsConfigCount).isEqualTo(0); - Map, GenericXdsConfig> configDumps = mapConfigDumps(xdsClient, clientConfig); - assertThat(configDumps).isEmpty(); - } - /** * Assuming {@link EnvoyProtoDataTest#convertNode} passes, perform a minimal check, * just verify the node itself is the one we expect. 
@@ -464,21 +445,17 @@ private static void verifyClientConfigNode(ClientConfig clientConfig) { assertThat(node).isEqualTo(BOOTSTRAP_NODE.toEnvoyProtoNode()); } - private static Map, GenericXdsConfig> mapConfigDumps(FakeXdsClient client, - ClientConfig config) { - Map, GenericXdsConfig> xdsConfigMap = new HashMap<>(); + private static Map mapConfigDumps(ClientConfig config) { + Map xdsConfigMap = new HashMap<>(); List xdsConfigList = config.getGenericXdsConfigsList(); for (GenericXdsConfig genericXdsConfig : xdsConfigList) { - XdsResourceType type = client.getSubscribedResourceTypesWithTypeUrl() - .get(genericXdsConfig.getTypeUrl()); - assertThat(type).isNotNull(); - assertThat(xdsConfigMap).doesNotContainKey(type); - xdsConfigMap.put(type, genericXdsConfig); + assertThat(xdsConfigMap).doesNotContainKey(genericXdsConfig.getTypeUrl()); + xdsConfigMap.put(genericXdsConfig.getTypeUrl(), genericXdsConfig); } return xdsConfigMap; } - private static class FakeXdsClient extends XdsClient implements XdsClient.ResourceStore { + private static class FakeXdsClient extends XdsClient { protected Map, Map> getSubscribedResourcesMetadata() { return ImmutableMap.of(); @@ -494,25 +471,11 @@ private static class FakeXdsClient extends XdsClient implements XdsClient.Resour public BootstrapInfo getBootstrapInfo() { return BOOTSTRAP_INFO; } - - @Nullable - @Override - public Collection getSubscribedResources(ServerInfo serverInfo, - XdsResourceType type) { - return null; - } - - @Override - public Map> getSubscribedResourceTypesWithTypeUrl() { - return ImmutableMap.of(); - } - } private static class FakeXdsClientPoolFactory implements XdsClientPoolFactory { private final Map xdsClientMap = new HashMap<>(); - private boolean isOldStyle - ; + private boolean isOldStyle; private FakeXdsClientPoolFactory(@Nullable XdsClient xdsClient) { if (xdsClient != null) { @@ -550,12 +513,8 @@ public List getTargets() { } @Override - public void setBootstrapOverride(Map bootstrap) { - throw new 
UnsupportedOperationException("Should not be called"); - } - - @Override - public ObjectPool getOrCreate(String target) { + public ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder) { throw new UnsupportedOperationException("Should not be called"); } } diff --git a/xds/src/test/java/io/grpc/xds/DataPlaneRule.java b/xds/src/test/java/io/grpc/xds/DataPlaneRule.java index faa79444071..b308419d142 100644 --- a/xds/src/test/java/io/grpc/xds/DataPlaneRule.java +++ b/xds/src/test/java/io/grpc/xds/DataPlaneRule.java @@ -48,6 +48,7 @@ public class DataPlaneRule extends TestWatcher { private static final Logger logger = Logger.getLogger(DataPlaneRule.class.getName()); private static final String SERVER_HOST_NAME = "test-server"; + static final String ENDPOINT_HOST_NAME = "endpoint-host-name"; private static final String SCHEME = "test-xds"; private final ControlPlaneRule controlPlane; @@ -73,7 +74,8 @@ public Server getServer() { */ public ManagedChannel getManagedChannel() { ManagedChannel channel = Grpc.newChannelBuilder(SCHEME + ":///" + SERVER_HOST_NAME, - InsecureChannelCredentials.create()).build(); + InsecureChannelCredentials.create()) + .build(); channels.add(channel); return channel; } @@ -98,7 +100,7 @@ protected void starting(Description description) { InetSocketAddress edsInetSocketAddress = (InetSocketAddress) server.getListenSockets().get(0); controlPlane.setEdsConfig( ControlPlaneRule.buildClusterLoadAssignment(edsInetSocketAddress.getHostName(), - edsInetSocketAddress.getPort())); + ENDPOINT_HOST_NAME, edsInetSocketAddress.getPort())); } @Override @@ -124,10 +126,12 @@ protected void finished(Description description) { } private void startServer(Map bootstrapOverride) throws Exception { + final String[] authority = new String[1]; ServerInterceptor metadataInterceptor = new ServerInterceptor() { @Override public ServerCall.Listener interceptCall(ServerCall call, Metadata requestHeaders, ServerCallHandler next) 
{ + authority[0] = call.getAuthority(); logger.fine("Received following metadata: " + requestHeaders); // Make a copy of the headers so that it can be read in a thread-safe manner when copying @@ -155,8 +159,12 @@ public void close(Status status, Metadata trailers) { @Override public void unaryRpc( SimpleRequest request, StreamObserver responseObserver) { + String responseMsg = "Hi, xDS!"; + if (authority[0] != null) { + responseMsg += " Authority= " + authority[0]; + } SimpleResponse response = - SimpleResponse.newBuilder().setResponseMessage("Hi, xDS!").build(); + SimpleResponse.newBuilder().setResponseMessage(responseMsg).build(); responseObserver.onNext(response); responseObserver.onCompleted(); } diff --git a/xds/src/test/java/io/grpc/xds/ExtAuthzConfigParserTest.java b/xds/src/test/java/io/grpc/xds/ExtAuthzConfigParserTest.java new file mode 100644 index 00000000000..fa2718cbe63 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/ExtAuthzConfigParserTest.java @@ -0,0 +1,297 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.RuntimeFeatureFlag; +import io.envoyproxy.envoy.config.core.v3.RuntimeFractionalPercent; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.type.matcher.v3.ListStringMatcher; +import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.envoyproxy.envoy.type.v3.FractionalPercent; +import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; +import io.grpc.Status; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.EnvoyProtoData.Node; +import io.grpc.xds.internal.Matchers; +import io.grpc.xds.internal.extauthz.ExtAuthzConfig; +import io.grpc.xds.internal.extauthz.ExtAuthzParseException; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesConfig; +import java.util.Collections; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ExtAuthzConfigParserTest { + + private static final Any GOOGLE_DEFAULT_CHANNEL_CREDS = + Any.pack(GoogleDefaultCredentials.newBuilder().build()); + private static final Any FAKE_ACCESS_TOKEN_CALL_CREDS = + 
Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); + + private static BootstrapInfo dummyBootstrapInfo() { + return BootstrapInfo.builder() + .servers( + Collections.singletonList(ServerInfo.create("test_target", Collections.emptyMap()))) + .node(Node.newBuilder().build()).build(); + } + + private static ServerInfo dummyServerInfo() { + return ServerInfo.create("test_target", Collections.emptyMap(), false, true, false, false); + } + + private ExtAuthz.Builder extAuthzBuilder; + + @Before + public void setUp() { + extAuthzBuilder = ExtAuthz.newBuilder() + .setGrpcService(GrpcService.newBuilder().setGoogleGrpc(GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test-cluster") + .addChannelCredentialsPlugin(GOOGLE_DEFAULT_CHANNEL_CREDS) + .addCallCredentialsPlugin(FAKE_ACCESS_TOKEN_CALL_CREDS).build()) + .build()); + } + + @Test + public void parse_missingGrpcService_throws() { + ExtAuthz extAuthz = ExtAuthz.newBuilder().build(); + try { + ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat() + .isEqualTo("unsupported ExtAuthz service type: only grpc_service is supported"); + } + } + + @Test + public void parse_invalidGrpcService_throws() { + ExtAuthz extAuthz = ExtAuthz.newBuilder() + .setGrpcService(GrpcService.newBuilder().build()) + .build(); + try { + ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().startsWith("Failed to parse GrpcService config:"); + } + } + + @Test + public void parse_invalidAllowExpression_throws() { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("[invalid").build()).build()) + .build(); + try { + 
ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().startsWith("Invalid regex pattern for allow_expression:"); + } + } + + @Test + public void parse_invalidDisallowExpression_throws() { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("[invalid").build()).build()) + .build(); + try { + ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().startsWith("Invalid regex pattern for disallow_expression:"); + } + } + + @Test + public void parse_success() throws ExtAuthzParseException { + ExtAuthz extAuthz = + extAuthzBuilder + .setGrpcService(extAuthzBuilder.getGrpcServiceBuilder() + .setTimeout(com.google.protobuf.Duration.newBuilder().setSeconds(5).build()) + .addInitialMetadata( + HeaderValue.newBuilder().setKey("key").setValue("value").build()) + .build()) + .setFailureModeAllow(true).setFailureModeAllowHeaderAdd(true) + .setIncludePeerCertificate(true) + .setStatusOnError( + io.envoyproxy.envoy.type.v3.HttpStatus.newBuilder().setCodeValue(403).build()) + .setDenyAtDisable( + RuntimeFeatureFlag.newBuilder().setDefaultValue(BoolValue.of(true)).build()) + .setFilterEnabled(RuntimeFractionalPercent.newBuilder() + .setDefaultValue(FractionalPercent.newBuilder().setNumerator(50) + .setDenominator(DenominatorType.TEN_THOUSAND).build()) + .build()) + .setAllowedHeaders(ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("allowed-header").build()).build()) + .setDisallowedHeaders(ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setPrefix("disallowed-").build()).build()) + 
.setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("allow.*").build()) + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("disallow.*").build()) + .setDisallowAll(BoolValue.of(true)).setDisallowIsError(BoolValue.of(true)).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.grpcService().googleGrpc().target()).isEqualTo("test-cluster"); + assertThat(config.grpcService().timeout().get().getSeconds()).isEqualTo(5); + assertThat(config.grpcService().initialMetadata()).isNotEmpty(); + assertThat(config.failureModeAllow()).isTrue(); + assertThat(config.failureModeAllowHeaderAdd()).isTrue(); + assertThat(config.includePeerCertificate()).isTrue(); + assertThat(config.statusOnError().getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(config.statusOnError().getDescription()).isEqualTo("HTTP status code 403"); + assertThat(config.denyAtDisable()).isTrue(); + assertThat(config.filterEnabled()).isEqualTo(Matchers.FractionMatcher.create(50, 10_000)); + assertThat(config.allowedHeaders()).hasSize(1); + assertThat(config.allowedHeaders().get(0).matches("allowed-header")).isTrue(); + assertThat(config.disallowedHeaders()).hasSize(1); + assertThat(config.disallowedHeaders().get(0).matches("disallowed-foo")).isTrue(); + assertThat(config.decoderHeaderMutationRules().isPresent()).isTrue(); + HeaderMutationRulesConfig rules = config.decoderHeaderMutationRules().get(); + assertThat(rules.allowExpression().get().pattern()).isEqualTo("allow.*"); + assertThat(rules.disallowExpression().get().pattern()).isEqualTo("disallow.*"); + assertThat(rules.disallowAll()).isTrue(); + assertThat(rules.disallowIsError()).isTrue(); + } + + @Test + public void parse_saneDefaults() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder.build(); + + ExtAuthzConfig config = 
ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.failureModeAllow()).isFalse(); + assertThat(config.failureModeAllowHeaderAdd()).isFalse(); + assertThat(config.includePeerCertificate()).isFalse(); + assertThat(config.statusOnError()).isEqualTo(Status.PERMISSION_DENIED); + assertThat(config.denyAtDisable()).isFalse(); + assertThat(config.filterEnabled()).isEqualTo(Matchers.FractionMatcher.create(100, 100)); + assertThat(config.allowedHeaders()).isEmpty(); + assertThat(config.disallowedHeaders()).isEmpty(); + assertThat(config.decoderHeaderMutationRules().isPresent()).isFalse(); + } + + @Test + public void parse_headerMutationRules_allowExpressionOnly() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("allow.*").build()).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.decoderHeaderMutationRules().isPresent()).isTrue(); + HeaderMutationRulesConfig rules = config.decoderHeaderMutationRules().get(); + assertThat(rules.allowExpression().get().pattern()).isEqualTo("allow.*"); + assertThat(rules.disallowExpression().isPresent()).isFalse(); + } + + @Test + public void parse_headerMutationRules_disallowExpressionOnly() throws ExtAuthzParseException { + ExtAuthz extAuthz = + extAuthzBuilder.setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("disallow.*").build()) + .build()).build(); + + ExtAuthzConfig config = ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.decoderHeaderMutationRules().isPresent()).isTrue(); + HeaderMutationRulesConfig rules = config.decoderHeaderMutationRules().get(); + assertThat(rules.allowExpression().isPresent()).isFalse(); + 
assertThat(rules.disallowExpression().get().pattern()).isEqualTo("disallow.*"); + } + + @Test + public void parse_filterEnabled_hundred() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setFilterEnabled(RuntimeFractionalPercent.newBuilder().setDefaultValue(FractionalPercent + .newBuilder().setNumerator(25).setDenominator(DenominatorType.HUNDRED).build()).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.filterEnabled()).isEqualTo(Matchers.FractionMatcher.create(25, 100)); + } + + @Test + public void parse_filterEnabled_million() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setFilterEnabled( + RuntimeFractionalPercent.newBuilder().setDefaultValue(FractionalPercent.newBuilder() + .setNumerator(123456).setDenominator(DenominatorType.MILLION).build()).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.filterEnabled()) + .isEqualTo(Matchers.FractionMatcher.create(123456, 1_000_000)); + } + + @Test + public void parse_filterEnabled_unrecognizedDenominator() { + ExtAuthz extAuthz = extAuthzBuilder.setFilterEnabled(RuntimeFractionalPercent.newBuilder() + .setDefaultValue( + FractionalPercent.newBuilder().setNumerator(1).setDenominatorValue(4).build()) + .build()).build(); + + try { + ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().isEqualTo("Unknown denominator type: UNRECOGNIZED"); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java index 30c2403396e..a273c6f3ebf 100644 --- a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java +++ 
b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java @@ -18,9 +18,13 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.DataPlaneRule.ENDPOINT_HOST_NAME; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; import static org.junit.Assert.assertEquals; import com.github.xds.type.v3.TypedStruct; +import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; import com.google.protobuf.Struct; import com.google.protobuf.Value; @@ -35,13 +39,21 @@ import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; +import io.envoyproxy.envoy.config.route.v3.Route; +import io.envoyproxy.envoy.config.route.v3.RouteAction; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; +import io.envoyproxy.envoy.config.route.v3.RouteMatch; +import io.envoyproxy.envoy.config.route.v3.VirtualHost; import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; +import io.grpc.ClientStreamTracer; +import io.grpc.FlagResetRule; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener; +import io.grpc.InternalFeatureFlags; import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; import io.grpc.Metadata; @@ -50,10 +62,14 @@ import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; import java.net.InetSocketAddress; +import java.util.Arrays; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import 
org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; /** * Xds integration tests using a local control plane, implemented in {@link @@ -75,27 +91,58 @@ * 3) Construct EDS config w/ test server address from 2). Set CDS and EDS Config at the Control * Plane. Then start the test xDS client (requires EDS to do xDS name resolution). */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class FakeControlPlaneXdsIntegrationTest { @Rule(order = 0) public ControlPlaneRule controlPlane = new ControlPlaneRule(); @Rule(order = 1) public DataPlaneRule dataPlane = new DataPlaneRule(controlPlane); + @Rule(order = 2) + public final FlagResetRule flagResetRule = new FlagResetRule(); + + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + + @Before + public void setupRfc3986UrisFeatureFlag() throws Exception { + flagResetRule.setFlagForTest( + InternalFeatureFlags::setRfc3986UrisEnabled, enableRfc3986UrisParam); + } @Test public void pingPong() throws Exception { ManagedChannel channel = dataPlane.getManagedChannel(); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( channel); - SimpleRequest request = SimpleRequest.newBuilder() - .build(); + SimpleRequest request = SimpleRequest.getDefaultInstance(); SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setResponseMessage("Hi, xDS!") + .setResponseMessage("Hi, xDS! 
Authority= test-server") .build(); assertEquals(goldenResponse, blockingStub.unaryRpc(request)); } + @Test + public void pingPong_edsEndpoint_authorityOverride() throws Exception { + System.setProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", "true"); + try { + ManagedChannel channel = dataPlane.getManagedChannel(); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( + channel); + SimpleRequest request = SimpleRequest.getDefaultInstance(); + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS! Authority= " + ENDPOINT_HOST_NAME) + .build(); + assertEquals(goldenResponse, blockingStub.unaryRpc(request)); + } finally { + System.clearProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE"); + } + } + @Test public void pingPong_metadataLoadBalancer() throws Exception { MetadataLoadBalancerProvider metadataLbProvider = new MetadataLoadBalancerProvider(); @@ -126,10 +173,9 @@ public void pingPong_metadataLoadBalancer() throws Exception { // We add an interceptor to catch the response headers from the server. SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( dataPlane.getManagedChannel()).withInterceptors(responseHeaderInterceptor); - SimpleRequest request = SimpleRequest.newBuilder() - .build(); + SimpleRequest request = SimpleRequest.getDefaultInstance(); SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setResponseMessage("Hi, xDS!") + .setResponseMessage("Hi, xDS! Authority= test-server") .build(); assertEquals(goldenResponse, blockingStub.unaryRpc(request)); @@ -141,6 +187,100 @@ public void pingPong_metadataLoadBalancer() throws Exception { } } + // Try to trigger "UNAVAILABLE: CDS encountered error: unable to find available subchannel for + // cluster cluster:cluster1" race, if XdsNameResolver updates its ConfigSelector before + // cluster_manager config. 
+ @Test + public void changeClusterForRoute() throws Exception { + // Start with route to cluster0 + InetSocketAddress edsInetSocketAddress + = (InetSocketAddress) dataPlane.getServer().getListenSockets().get(0); + controlPlane.getService().setXdsConfig( + ADS_TYPE_URL_EDS, + ImmutableMap.of( + "eds-service-0", + ControlPlaneRule.buildClusterLoadAssignment( + edsInetSocketAddress.getHostName(), "", edsInetSocketAddress.getPort(), + "eds-service-0"), + "eds-service-1", + ControlPlaneRule.buildClusterLoadAssignment( + edsInetSocketAddress.getHostName(), "", edsInetSocketAddress.getPort(), + "eds-service-1"))); + controlPlane.getService().setXdsConfig( + ADS_TYPE_URL_CDS, + ImmutableMap.of( + "cluster0", + ControlPlaneRule.buildCluster("cluster0", "eds-service-0"), + "cluster1", + ControlPlaneRule.buildCluster("cluster1", "eds-service-1"))); + controlPlane.setRdsConfig(RouteConfiguration.newBuilder() + .setName("route-config.googleapis.com") + .addVirtualHosts(VirtualHost.newBuilder() + .addDomains("test-server") + .addRoutes(Route.newBuilder() + .setMatch(RouteMatch.newBuilder().setPrefix("/").build()) + .setRoute(RouteAction.newBuilder().setCluster("cluster0").build()) + .build()) + .build()) + .build()); + + class ClusterClientStreamTracer extends ClientStreamTracer { + boolean usedCluster1; + + @Override + public void addOptionalLabel(String key, String value) { + if ("grpc.lb.backend_service".equals(key)) { + usedCluster1 = "cluster1".equals(value); + } + } + } + + ClusterClientStreamTracer tracer = new ClusterClientStreamTracer(); + ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.Factory() { + @Override + public ClientStreamTracer newClientStreamTracer( + ClientStreamTracer.StreamInfo info, Metadata headers) { + return tracer; + } + }; + ClientInterceptor tracerInterceptor = new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, 
callOptions.withStreamTracerFactory(tracerFactory)); + } + }; + SimpleServiceGrpc.SimpleServiceBlockingStub stub = SimpleServiceGrpc + .newBlockingStub(dataPlane.getManagedChannel()) + .withInterceptors(tracerInterceptor); + SimpleRequest request = SimpleRequest.getDefaultInstance(); + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS! Authority= test-server") + .build(); + assertThat(stub.unaryRpc(request)).isEqualTo(goldenResponse); + assertThat(tracer.usedCluster1).isFalse(); + + // Check for errors when swapping route to cluster1 + controlPlane.setRdsConfig(RouteConfiguration.newBuilder() + .setName("route-config.googleapis.com") + .addVirtualHosts(VirtualHost.newBuilder() + .addDomains("test-server") + .addRoutes(Route.newBuilder() + .setMatch(RouteMatch.newBuilder().setPrefix("/").build()) + .setRoute(RouteAction.newBuilder().setCluster("cluster1").build()) + .build()) + .build()) + .build()); + + for (int j = 0; j < 10; j++) { + stub.unaryRpc(request); + if (tracer.usedCluster1) { + break; + } + } + assertThat(tracer.usedCluster1).isTrue(); + } + // Captures response headers from the server. private static class ResponseHeaderClientInterceptor implements ClientInterceptor { Metadata reponseHeaders; @@ -180,41 +320,44 @@ public void pingPong_ringHash() { ManagedChannel channel = dataPlane.getManagedChannel(); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( channel); - SimpleRequest request = SimpleRequest.newBuilder() - .build(); + SimpleRequest request = SimpleRequest.getDefaultInstance(); SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setResponseMessage("Hi, xDS!") + .setResponseMessage("Hi, xDS! 
Authority= test-server") .build(); assertEquals(goldenResponse, blockingStub.unaryRpc(request)); } @Test - public void pingPong_logicalDns() { - InetSocketAddress serverAddress = - (InetSocketAddress) dataPlane.getServer().getListenSockets().get(0); - controlPlane.setCdsConfig( - ControlPlaneRule.buildCluster().toBuilder() - .setType(Cluster.DiscoveryType.LOGICAL_DNS) - .setLoadAssignment( - ClusterLoadAssignment.newBuilder().addEndpoints( - LocalityLbEndpoints.newBuilder().addLbEndpoints( - LbEndpoint.newBuilder().setEndpoint( - Endpoint.newBuilder().setAddress( - Address.newBuilder().setSocketAddress( - SocketAddress.newBuilder() - .setAddress("localhost") - .setPortValue(serverAddress.getPort())))))) - .build()) - .build()); + public void pingPong_logicalDns_authorityOverride() { + System.setProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", "true"); + try { + InetSocketAddress serverAddress = + (InetSocketAddress) dataPlane.getServer().getListenSockets().get(0); + controlPlane.setCdsConfig( + ControlPlaneRule.buildCluster().toBuilder() + .setType(Cluster.DiscoveryType.LOGICAL_DNS) + .setLoadAssignment( + ClusterLoadAssignment.newBuilder().addEndpoints( + LocalityLbEndpoints.newBuilder().addLbEndpoints( + LbEndpoint.newBuilder().setEndpoint( + Endpoint.newBuilder().setAddress( + Address.newBuilder().setSocketAddress( + SocketAddress.newBuilder() + .setAddress("localhost") + .setPortValue(serverAddress.getPort())))))) + .build()) + .build()); - ManagedChannel channel = dataPlane.getManagedChannel(); - SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( - channel); - SimpleRequest request = SimpleRequest.newBuilder() - .build(); - SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setResponseMessage("Hi, xDS!") - .build(); - assertEquals(goldenResponse, blockingStub.unaryRpc(request)); + ManagedChannel channel = dataPlane.getManagedChannel(); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = 
SimpleServiceGrpc.newBlockingStub( + channel); + SimpleRequest request = SimpleRequest.getDefaultInstance(); + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS! Authority= localhost:" + serverAddress.getPort()) + .build(); + assertEquals(goldenResponse, blockingStub.unaryRpc(request)); + } finally { + System.clearProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE"); + } } } diff --git a/xds/src/test/java/io/grpc/xds/FaultFilterTest.java b/xds/src/test/java/io/grpc/xds/FaultFilterTest.java index f85f29ec0a3..8f0a33951b0 100644 --- a/xds/src/test/java/io/grpc/xds/FaultFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/FaultFilterTest.java @@ -33,16 +33,23 @@ /** Tests for {@link FaultFilter}. */ @RunWith(JUnit4.class) public class FaultFilterTest { + private static final FaultFilter.Provider FILTER_PROVIDER = new FaultFilter.Provider(); + + @Test + public void filterType_clientOnly() { + assertThat(FILTER_PROVIDER.isClientFilter()).isTrue(); + assertThat(FILTER_PROVIDER.isServerFilter()).isFalse(); + } @Test public void parseFaultAbort_convertHttpStatus() { Any rawConfig = Any.pack( HTTPFault.newBuilder().setAbort(FaultAbort.newBuilder().setHttpStatus(404)).build()); - FaultConfig faultConfig = FaultFilter.INSTANCE.parseFilterConfig(rawConfig).config; + FaultConfig faultConfig = FILTER_PROVIDER.parseFilterConfig(rawConfig).config; assertThat(faultConfig.faultAbort().status().getCode()) .isEqualTo(GrpcUtil.httpStatusToGrpcStatus(404).getCode()); - FaultConfig faultConfigOverride = - FaultFilter.INSTANCE.parseFilterConfigOverride(rawConfig).config; + + FaultConfig faultConfigOverride = FILTER_PROVIDER.parseFilterConfigOverride(rawConfig).config; assertThat(faultConfigOverride.faultAbort().status().getCode()) .isEqualTo(GrpcUtil.httpStatusToGrpcStatus(404).getCode()); } @@ -54,7 +61,7 @@ public void parseFaultAbort_withHeaderAbort() { .setPercentage(FractionalPercent.newBuilder() 
.setNumerator(20).setDenominator(DenominatorType.HUNDRED)) .setHeaderAbort(HeaderAbort.getDefaultInstance()).build(); - FaultConfig.FaultAbort faultAbort = FaultFilter.parseFaultAbort(proto).config; + FaultConfig.FaultAbort faultAbort = FaultFilter.Provider.parseFaultAbort(proto).config; assertThat(faultAbort.headerAbort()).isTrue(); assertThat(faultAbort.percent().numerator()).isEqualTo(20); assertThat(faultAbort.percent().denominatorType()) @@ -68,7 +75,7 @@ public void parseFaultAbort_withHttpStatus() { .setPercentage(FractionalPercent.newBuilder() .setNumerator(100).setDenominator(DenominatorType.TEN_THOUSAND)) .setHttpStatus(400).build(); - FaultConfig.FaultAbort res = FaultFilter.parseFaultAbort(proto).config; + FaultConfig.FaultAbort res = FaultFilter.Provider.parseFaultAbort(proto).config; assertThat(res.percent().numerator()).isEqualTo(100); assertThat(res.percent().denominatorType()) .isEqualTo(FaultConfig.FractionalPercent.DenominatorType.TEN_THOUSAND); @@ -82,7 +89,7 @@ public void parseFaultAbort_withGrpcStatus() { .setPercentage(FractionalPercent.newBuilder() .setNumerator(600).setDenominator(DenominatorType.MILLION)) .setGrpcStatus(Code.DEADLINE_EXCEEDED.value()).build(); - FaultConfig.FaultAbort faultAbort = FaultFilter.parseFaultAbort(proto).config; + FaultConfig.FaultAbort faultAbort = FaultFilter.Provider.parseFaultAbort(proto).config; assertThat(faultAbort.percent().numerator()).isEqualTo(600); assertThat(faultAbort.percent().denominatorType()) .isEqualTo(FaultConfig.FractionalPercent.DenominatorType.MILLION); diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java index 685102477cc..722f915dbea 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableList; 
import com.google.common.collect.ImmutableMap; +import com.google.common.net.InetAddresses; import com.google.common.util.concurrent.SettableFuture; import io.grpc.ServerInterceptor; import io.grpc.internal.TestUtils.NoopChannelLogger; @@ -58,7 +59,6 @@ import io.netty.handler.codec.http2.Http2Settings; import java.net.InetSocketAddress; import java.net.SocketAddress; -import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -318,7 +318,8 @@ public void destPrefixRangeMatch() throws Exception { EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMatch = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.1.2.0", 24)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 24)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -360,7 +361,8 @@ public void destPrefixRangeMismatch_returnDefaultFilterChain() EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMismatch = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.2.2.0", 24)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.2.2.0"), 24)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -403,7 +405,8 @@ public void dest0LengthPrefixRange() EnvoyServerProtoData.FilterChainMatch filterChainMatch0Length = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.2.2.0", 0)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.2.2.0"), 0)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -444,7 +447,8 @@ public void destPrefixRange_moreSpecificWins() EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = 
EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.1.2.0", 24)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 24)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -461,7 +465,8 @@ public void destPrefixRange_moreSpecificWins() EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.1.2.2", 31)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.2"), 31)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -519,7 +524,8 @@ public void destPrefixRange_emptyListLessSpecific() EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("8.0.0.0", 5)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("8.0.0.0"), 5)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -559,7 +565,8 @@ public void destPrefixRangeIpv6_moreSpecificWins() EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("FE80:0:0:0:0:0:0:0", 60)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("FE80:0:0:0:0:0:0:0"), 60)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -577,7 +584,8 @@ public void destPrefixRangeIpv6_moreSpecificWins() EnvoyServerProtoData.FilterChainMatch.create( 0, ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("FE80:0000:0000:0000:0202:0:0:0", 80)), + EnvoyServerProtoData.CidrRange.create( + 
InetAddresses.forString("FE80:0000:0000:0000:0202:0:0:0"), 80)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -620,8 +628,10 @@ public void destPrefixRange_moreSpecificWith2Wins() EnvoyServerProtoData.FilterChainMatch.create( 0, ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.1.2.0", 24), - EnvoyServerProtoData.CidrRange.create(LOCAL_IP, 32)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 24), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString(LOCAL_IP), 32)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -638,7 +648,8 @@ public void destPrefixRange_moreSpecificWith2Wins() EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.1.2.2", 31)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.2"), 31)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -763,8 +774,10 @@ public void sourcePrefixRange_moreSpecificWith2Wins() ImmutableList.of(), ImmutableList.of(), ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.4.2.0", 24), - EnvoyServerProtoData.CidrRange.create(REMOTE_IP, 32)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.0"), 24), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString(REMOTE_IP), 32)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -781,7 +794,8 @@ public void sourcePrefixRange_moreSpecificWith2Wins() 0, ImmutableList.of(), ImmutableList.of(), - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.4.2.2", 31)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.2"), 31)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), 
ImmutableList.of(), @@ -811,8 +825,7 @@ filterChainLessSpecific, randomConfig("no-match")), } @Test - public void sourcePrefixRange_2Matchers_expectException() - throws UnknownHostException { + public void sourcePrefixRange_2Matchers_expectException() { ChannelHandler next = new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { @@ -831,8 +844,10 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { ImmutableList.of(), ImmutableList.of(), ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.4.2.0", 24), - EnvoyServerProtoData.CidrRange.create("192.168.10.2", 32)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.0"), 24), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("192.168.10.2"), 32)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -848,7 +863,8 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { 0, ImmutableList.of(), ImmutableList.of(), - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.4.2.0", 24)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.0"), 24)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -890,8 +906,10 @@ public void sourcePortMatch_exactMatchWinsOverEmptyList() throws Exception { ImmutableList.of(), ImmutableList.of(), ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.4.2.0", 24), - EnvoyServerProtoData.CidrRange.create("10.4.2.2", 31)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.0"), 24), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.2"), 31)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -908,7 +926,8 @@ public void sourcePortMatch_exactMatchWinsOverEmptyList() throws Exception { 0, ImmutableList.of(), ImmutableList.of(), - 
ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.4.2.2", 31)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.2"), 31)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(7000, 15000), ImmutableList.of(), @@ -966,7 +985,8 @@ public void filterChain_5stepMatch() throws Exception { PORT, ImmutableList.of(), ImmutableList.of(), - ImmutableList.of(EnvoyServerProtoData.CidrRange.create(REMOTE_IP, 32)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString(REMOTE_IP), 32)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -981,9 +1001,11 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.1.2.0", 30)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 30)), ImmutableList.of(), - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.4.0.0", 16)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.0.0"), 16)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -997,8 +1019,10 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChainMatch.create( 0, ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("192.168.2.0", 24), - EnvoyServerProtoData.CidrRange.create("10.1.2.0", 30)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("192.168.2.0"), 24), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 30)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, @@ -1015,10 +1039,13 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChainMatch.create( 0, ImmutableList.of( - 
EnvoyServerProtoData.CidrRange.create("10.1.0.0", 16), - EnvoyServerProtoData.CidrRange.create("10.1.2.0", 30)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.0.0"), 16), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 30)), ImmutableList.of(), - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.4.2.0", 24)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.0"), 24)), EnvoyServerProtoData.ConnectionSourceType.EXTERNAL, ImmutableList.of(16000, 9000), ImmutableList.of(), @@ -1034,12 +1061,16 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChainMatch.create( 0, ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.1.0.0", 16), - EnvoyServerProtoData.CidrRange.create("10.1.2.0", 30)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.0.0"), 16), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 30)), ImmutableList.of(), ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.4.2.0", 24), - EnvoyServerProtoData.CidrRange.create("192.168.2.0", 24)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.0"), 24), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("192.168.2.0"), 24)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(15000, 8000), ImmutableList.of(), @@ -1053,7 +1084,8 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChainMatch filterChainMatch6 = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.1.2.0", 29)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 29)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -1093,7 +1125,6 @@ public void filterChain_5stepMatch() throws Exception { } @Test - 
@SuppressWarnings("deprecation") public void filterChainMatch_unsupportedMatchers() throws Exception { EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "ROOTCA"); @@ -1105,8 +1136,8 @@ public void filterChainMatch_unsupportedMatchers() throws Exception { EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = EnvoyServerProtoData.FilterChainMatch.create( 0 /* destinationPort */, - ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.1.0.0", 16)) /* prefixRange */, + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.0.0"), 16)) /* prefixRange */, ImmutableList.of("managed-mtls", "h2") /* applicationProtocol */, ImmutableList.of() /* sourcePrefixRanges */, EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, @@ -1117,8 +1148,8 @@ public void filterChainMatch_unsupportedMatchers() throws Exception { EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = EnvoyServerProtoData.FilterChainMatch.create( 0 /* destinationPort */, - ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.0.0.0", 8)) /* prefixRange */, + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.0.0.0"), 8)) /* prefixRange */, ImmutableList.of() /* applicationProtocol */, ImmutableList.of() /* sourcePrefixRanges */, EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, @@ -1162,7 +1193,7 @@ public void filterChainMatch_unsupportedMatchers() throws Exception { assertThat(sslSet.get()).isEqualTo(defaultFilterChain.sslContextProviderSupplier()); assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext().getCommonTlsContext() - .getTlsCertificateCertificateProviderInstance() + .getTlsCertificateProviderInstance() .getCertificateName()).isEqualTo("CERT3"); } diff --git a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java 
b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java new file mode 100644 index 00000000000..f252c6f4ec1 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java @@ -0,0 +1,524 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsNameResolver.CLUSTER_SELECTION_KEY; +import static io.grpc.xds.XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY; +import static io.grpc.xds.XdsTestUtils.CLUSTER_NAME; +import static io.grpc.xds.XdsTestUtils.EDS_NAME; +import static io.grpc.xds.XdsTestUtils.ENDPOINT_HOSTNAME; +import static io.grpc.xds.XdsTestUtils.ENDPOINT_PORT; +import static io.grpc.xds.XdsTestUtils.RDS_NAME; +import static io.grpc.xds.XdsTestUtils.buildRouteConfiguration; +import static io.grpc.xds.XdsTestUtils.getWrrLbConfigAsMap; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import 
com.google.protobuf.Any; +import com.google.protobuf.Empty; +import com.google.protobuf.Message; +import com.google.protobuf.UInt64Value; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; +import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig; +import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.TokenCacheConfig; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.Endpoints.LbEndpoint; +import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; +import io.grpc.xds.GcpAuthenticationFilter.FailingClientCall; +import io.grpc.xds.GcpAuthenticationFilter.GcpAuthenticationConfig; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsConfig.XdsClusterConfig; +import io.grpc.xds.XdsConfig.XdsClusterConfig.EndpointConfig; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.XdsListenerResource.LdsUpdate; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; +import io.grpc.xds.client.Locality; +import io.grpc.xds.client.XdsResourceType; +import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class GcpAuthenticationFilterTest { + private static final GcpAuthenticationFilter.Provider FILTER_PROVIDER = + new GcpAuthenticationFilter.Provider(); + private static final LdsUpdate ldsUpdate = getLdsUpdate(); + private static final EdsUpdate edsUpdate = getEdsUpdate(); 
+ private static final RdsUpdate rdsUpdate = getRdsUpdate(); + private static final CdsUpdate cdsUpdate = getCdsUpdate(); + + @Before + public void setUp() { + System.setProperty("GRPC_EXPERIMENTAL_XDS_GCP_AUTHENTICATION_FILTER", "true"); + } + + @Test + public void testNewFilterInstancesPerFilterName() { + assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10)) + .isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10)); + } + + @Test + public void filterType_clientOnly() { + assertThat(FILTER_PROVIDER.isClientFilter()).isTrue(); + assertThat(FILTER_PROVIDER.isServerFilter()).isFalse(); + } + + @Test + public void testParseFilterConfig_withValidConfig() { + GcpAuthnFilterConfig config = GcpAuthnFilterConfig.newBuilder() + .setCacheConfig(TokenCacheConfig.newBuilder().setCacheSize(UInt64Value.of(20))) + .build(); + Any anyMessage = Any.pack(config); + + ConfigOrError result = FILTER_PROVIDER.parseFilterConfig(anyMessage); + + assertNotNull(result.config); + assertNull(result.errorDetail); + assertEquals(20L, result.config.getCacheSize()); + } + + @Test + public void testParseFilterConfig_withZeroCacheSize() { + GcpAuthnFilterConfig config = GcpAuthnFilterConfig.newBuilder() + .setCacheConfig(TokenCacheConfig.newBuilder().setCacheSize(UInt64Value.of(0))) + .build(); + Any anyMessage = Any.pack(config); + + ConfigOrError result = FILTER_PROVIDER.parseFilterConfig(anyMessage); + + assertNull(result.config); + assertNotNull(result.errorDetail); + assertTrue(result.errorDetail.contains("cache_config.cache_size must be greater than zero")); + } + + @Test + public void testParseFilterConfig_withInvalidMessageType() { + Message invalidMessage = Empty.getDefaultInstance(); + ConfigOrError result = + FILTER_PROVIDER.parseFilterConfig(invalidMessage); + + assertNull(result.config); + assertThat(result.errorDetail).contains("Invalid config type"); + } + + @Test + public void testClientInterceptor_success() throws IOException, ResourceInvalidException 
{ + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); + + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + verify(mockChannel).newCall(eq(methodDescriptor), callOptionsCaptor.capture()); + CallOptions capturedOptions = callOptionsCaptor.getAllValues().get(0); + assertNotNull(capturedOptions.getCredentials()); + } + + @Test + public void testClientInterceptor_createsAndReusesCachedCredentials() + throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, 
defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); + + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + verify(mockChannel, times(2)) + .newCall(eq(methodDescriptor), callOptionsCaptor.capture()); + CallOptions firstCapturedOptions = callOptionsCaptor.getAllValues().get(0); + CallOptions secondCapturedOptions = callOptionsCaptor.getAllValues().get(1); + assertNotNull(firstCapturedOptions.getCredentials()); + assertNotNull(secondCapturedOptions.getCredentials()); + assertSame(firstCapturedOptions.getCredentials(), secondCapturedOptions.getCredentials()); + } + + @Test + public void testClientInterceptor_withoutClusterSelectionKey() throws Exception { + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + CallOptions callOptionsWithXds = CallOptions.DEFAULT; + + ClientCall call = interceptor.interceptCall( + methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("does not contain cluster resource"); + } + + @Test + public void 
testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exception { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + Channel mockChannel = mock(Channel.class); + + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + verify(mockChannel).newCall(methodDescriptor, callOptionsWithXds); + } + + @Test + public void testClientInterceptor_xdsConfigDoesNotExist() throws Exception { + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0"); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("does 
not contain xds configuration"); + } + + @Test + public void testClientInterceptor_incorrectClusterName() throws Exception { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster("custer0", StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("does not contain xds cluster"); + } + + @Test + public void testClientInterceptor_statusOrError() throws Exception { + StatusOr errorCluster = + StatusOr.fromStatus(Status.NOT_FOUND.withDescription("Cluster resource not found")); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, errorCluster).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new 
GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("Cluster resource not found"); + } + + @Test + public void testClientInterceptor_notAudienceWrapper() + throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + getCdsUpdateWithIncorrectAudienceWrapper(), + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("GCP Authn 
found wrong type"); + } + + @Test + public void testLruCacheAcrossInterceptors() throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2); + ClientInterceptor interceptor1 + = filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); + + interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + verify(mockChannel).newCall(eq(methodDescriptor), callOptionsCaptor.capture()); + CallOptions capturedOptions1 = callOptionsCaptor.getAllValues().get(0); + assertNotNull(capturedOptions1.getCredentials()); + ClientInterceptor interceptor2 + = filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + verify(mockChannel, times(2)) + .newCall(eq(methodDescriptor), callOptionsCaptor.capture()); + CallOptions capturedOptions2 = callOptionsCaptor.getAllValues().get(1); + assertNotNull(capturedOptions2.getCredentials()); + + assertSame(capturedOptions1.getCredentials(), capturedOptions2.getCredentials()); + } + + @Test + public void testLruCacheEvictionOnResize() throws IOException, ResourceInvalidException { 
+ XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + + ClientInterceptor interceptor1 = + filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null); + Channel mockChannel1 = Mockito.mock(Channel.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(CallOptions.class); + interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel1); + verify(mockChannel1).newCall(eq(methodDescriptor), captor.capture()); + CallOptions options1 = captor.getValue(); + // This will recreate the cache with max size of 1 and copy the credential for audience1. 
+ ClientInterceptor interceptor2 = + filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + Channel mockChannel2 = Mockito.mock(Channel.class); + interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel2); + verify(mockChannel2).newCall(eq(methodDescriptor), captor.capture()); + CallOptions options2 = captor.getValue(); + + assertSame(options1.getCredentials(), options2.getCredentials()); + + clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, getCdsUpdate2(), new EndpointConfig(StatusOr.fromValue(edsUpdate))); + defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + + // This will evict the credential for audience1 and add new credential for audience2 + ClientInterceptor interceptor3 = + filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + Channel mockChannel3 = Mockito.mock(Channel.class); + interceptor3.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel3); + verify(mockChannel3).newCall(eq(methodDescriptor), captor.capture()); + CallOptions options3 = captor.getValue(); + + assertNotSame(options1.getCredentials(), options3.getCredentials()); + + clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate))); + defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + 
+ // This will create new credential for audience1 because it has been evicted + ClientInterceptor interceptor4 = + filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + Channel mockChannel4 = Mockito.mock(Channel.class); + interceptor4.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel4); + verify(mockChannel4).newCall(eq(methodDescriptor), captor.capture()); + CallOptions options4 = captor.getValue(); + + assertNotSame(options1.getCredentials(), options4.getCredentials()); + } + + private static LdsUpdate getLdsUpdate() { + Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig( + "router", RouterFilter.ROUTER_CONFIG); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forRdsName( + 0L, RDS_NAME, Collections.singletonList(routerFilterConfig)); + return XdsListenerResource.LdsUpdate.forApiListener(httpConnectionManager); + } + + private static RdsUpdate getRdsUpdate() { + RouteConfiguration routeConfiguration = + buildRouteConfiguration("my-server", RDS_NAME, CLUSTER_NAME); + XdsResourceType.Args args = new XdsResourceType.Args( + XdsTestUtils.EMPTY_BOOTSTRAPPER_SERVER_INFO, "0", "0", null, null, null); + try { + return XdsRouteConfigureResource.getInstance().doParse(args, routeConfiguration); + } catch (ResourceInvalidException ex) { + return null; + } + } + + private static EdsUpdate getEdsUpdate() { + Map lbEndpointsMap = new HashMap<>(); + LbEndpoint lbEndpoint = LbEndpoint.create( + "127.0.0.5", ENDPOINT_PORT, 0, true, ENDPOINT_HOSTNAME, ImmutableMap.of()); + lbEndpointsMap.put( + Locality.create("", "", ""), + LocalityLbEndpoints.create(ImmutableList.of(lbEndpoint), 10, 0, ImmutableMap.of())); + return new XdsEndpointResource.EdsUpdate(EDS_NAME, lbEndpointsMap, Collections.emptyList()); + } + + private static CdsUpdate getCdsUpdate() { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("TEST_AUDIENCE")); + 
try { + CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false, null) + .lbPolicyConfig(getWrrLbConfigAsMap()); + return cdsUpdate.parsedMetadata(parsedMetadata.build()).build(); + } catch (IOException ex) { + return null; + } + } + + private static CdsUpdate getCdsUpdate2() { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("NEW_TEST_AUDIENCE")); + try { + CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false, null) + .lbPolicyConfig(getWrrLbConfigAsMap()); + return cdsUpdate.parsedMetadata(parsedMetadata.build()).build(); + } catch (IOException ex) { + return null; + } + } + + private static CdsUpdate getCdsUpdateWithIncorrectAudienceWrapper() throws IOException { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + parsedMetadata.put("FILTER_INSTANCE_NAME", "TEST_AUDIENCE"); + CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false, null) + .lbPolicyConfig(getWrrLbConfigAsMap()); + return cdsUpdate.parsedMetadata(parsedMetadata.build()).build(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java b/xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java index 30ea76b54f2..d4ee4159bc2 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verifyNoInteractions; @@ -27,11 +28,14 @@ import io.grpc.TlsChannelCredentials; import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil.GrpcBuildVersion; +import io.grpc.xds.client.AllowedGrpcServices; +import 
io.grpc.xds.client.AllowedGrpcServices.AllowedGrpcService; import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.Bootstrapper.AuthorityInfo; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.BootstrapperImpl; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.client.EnvoyProtoData.Node; import io.grpc.xds.client.Locality; import io.grpc.xds.client.XdsInitializationException; @@ -39,10 +43,9 @@ import java.util.List; import java.util.Map; import org.junit.After; +import org.junit.Assert; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -52,19 +55,18 @@ public class GrpcBootstrapperImplTest { private static final String BOOTSTRAP_FILE_PATH = "/fake/fs/path/bootstrap.json"; private static final String SERVER_URI = "trafficdirector.googleapis.com:443"; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); private final GrpcBootstrapperImpl bootstrapper = new GrpcBootstrapperImpl(); private String originalBootstrapPathFromEnvVar; private String originalBootstrapPathFromSysProp; private String originalBootstrapConfigFromEnvVar; private String originalBootstrapConfigFromSysProp; + private boolean originalExperimentalXdsFallbackFlag; @Before public void setUp() { saveEnvironment(); + originalExperimentalXdsFallbackFlag = CommonBootstrapperTestUtils.setEnableXdsFallback(true); bootstrapper.bootstrapPathFromEnvVar = BOOTSTRAP_FILE_PATH; } @@ -81,6 +83,73 @@ public void restoreEnvironment() { bootstrapper.bootstrapPathFromSysProp = originalBootstrapPathFromSysProp; bootstrapper.bootstrapConfigFromEnvVar = originalBootstrapConfigFromEnvVar; bootstrapper.bootstrapConfigFromSysProp = originalBootstrapConfigFromSysProp; + 
CommonBootstrapperTestUtils.setEnableXdsFallback(originalExperimentalXdsFallbackFlag); + } + + @Test + public void parseBootstrap_emptyServers_throws() { + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + XdsInitializationException e = Assert.assertThrows(XdsInitializationException.class, + bootstrapper::bootstrap); + assertThat(e).hasMessageThat().isEqualTo("Invalid bootstrap: 'xds_servers' is empty"); + } + + @Test + public void parseBootstrap_allowedGrpcServices() throws XdsInitializationException { + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [{\"type\": \"insecure\"}]\n" + + " }\n" + + " ],\n" + + " \"allowed_grpc_services\": {\n" + + " \"dns:///foo.com:443\": {\n" + + " \"channel_creds\": [{\"type\": \"insecure\"}],\n" + + " \"call_creds\": [{\"type\": \"access_token\"}]\n" + + " }\n" + + " }\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + GrpcBootstrapImplConfig customConfig = + (GrpcBootstrapImplConfig) info.implSpecificObject().get(); + AllowedGrpcServices allowed = customConfig.allowedGrpcServices(); + assertThat(allowed).isNotNull(); + assertThat(allowed.services()).containsKey("dns:///foo.com:443"); + AllowedGrpcService service = allowed.services().get("dns:///foo.com:443"); + assertThat(service.configuredChannelCredentials().channelCredentials()) + .isInstanceOf(InsecureChannelCredentials.class); + assertThat(service.callCredentials().isPresent()).isFalse(); + } + + @Test + public void parseBootstrap_allowedGrpcServices_invalidChannelCreds() { + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [{\"type\": \"insecure\"}]\n" + + " }\n" + + " ],\n" + + " \"allowed_grpc_services\": 
{\n" + + " \"dns:///foo.com:443\": {\n" + + " \"channel_creds\": []\n" + + " }\n" + + " }\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + XdsInitializationException e = assertThrows(XdsInitializationException.class, + bootstrapper::bootstrap); + assertThat(e).hasMessageThat() + .isEqualTo("Invalid bootstrap: server dns:///foo.com:443 'channel_creds' required"); } @Test @@ -232,7 +301,7 @@ public void parseBootstrap_IgnoreIrrelevantFields() throws XdsInitializationExce } @Test - public void parseBootstrap_missingServerChannelCreds() throws XdsInitializationException { + public void parseBootstrap_missingServerChannelCreds() { String rawData = "{\n" + " \"xds_servers\": [\n" + " {\n" @@ -242,13 +311,14 @@ public void parseBootstrap_missingServerChannelCreds() throws XdsInitializationE + "}"; bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); - thrown.expect(XdsInitializationException.class); - thrown.expectMessage("Invalid bootstrap: server " + SERVER_URI + " 'channel_creds' required"); - bootstrapper.bootstrap(); + XdsInitializationException e = Assert.assertThrows(XdsInitializationException.class, + bootstrapper::bootstrap); + assertThat(e).hasMessageThat() + .isEqualTo("Invalid bootstrap: server " + SERVER_URI + " 'channel_creds' required"); } @Test - public void parseBootstrap_unsupportedServerChannelCreds() throws XdsInitializationException { + public void parseBootstrap_unsupportedServerChannelCreds() { String rawData = "{\n" + " \"xds_servers\": [\n" + " {\n" @@ -261,9 +331,10 @@ public void parseBootstrap_unsupportedServerChannelCreds() throws XdsInitializat + "}"; bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); - thrown.expect(XdsInitializationException.class); - thrown.expectMessage("Server " + SERVER_URI + ": no supported channel credentials found"); - bootstrapper.bootstrap(); + XdsInitializationException e = assertThrows(XdsInitializationException.class, + 
bootstrapper::bootstrap); + assertThat(e).hasMessageThat() + .isEqualTo("Server " + SERVER_URI + ": no supported channel credentials found"); } @Test @@ -290,7 +361,7 @@ public void parseBootstrap_useFirstSupportedChannelCredentials() } @Test - public void parseBootstrap_noXdsServers() throws XdsInitializationException { + public void parseBootstrap_noXdsServers() { String rawData = "{\n" + " \"node\": {\n" + " \"id\": \"ENVOY_NODE_ID\",\n" @@ -308,9 +379,10 @@ public void parseBootstrap_noXdsServers() throws XdsInitializationException { + "}"; bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); - thrown.expect(XdsInitializationException.class); - thrown.expectMessage("Invalid bootstrap: 'xds_servers' does not exist."); - bootstrapper.bootstrap(); + XdsInitializationException e = assertThrows(XdsInitializationException.class, + bootstrapper::bootstrap); + assertThat(e).hasMessageThat() + .isEqualTo("Invalid bootstrap: 'xds_servers' does not exist."); } @Test @@ -339,15 +411,23 @@ public void parseBootstrap_serverWithoutServerUri() throws XdsInitializationExce + "}"; bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); - thrown.expectMessage("Invalid bootstrap: missing 'server_uri'"); - bootstrapper.bootstrap(); + XdsInitializationException e = assertThrows(XdsInitializationException.class, + bootstrapper::bootstrap); + assertThat(e).hasMessageThat().isEqualTo("Invalid bootstrap: missing 'server_uri'"); } @Test public void parseBootstrap_certProviderInstances() throws XdsInitializationException { String rawData = "{\n" - + " \"xds_servers\": [],\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ]\n" + + " }\n" + + " ],\n" + " \"certificate_providers\": {\n" + " \"gcp_id\": {\n" + " \"plugin_name\": \"meshca\",\n" @@ -384,7 +464,6 @@ public void parseBootstrap_certProviderInstances() throws 
XdsInitializationExcep bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); BootstrapInfo info = bootstrapper.bootstrap(); - assertThat(info.servers()).isEmpty(); assertThat(info.node()).isEqualTo(getNodeBuilder().build()); Map certProviders = info.certProviders(); assertThat(certProviders).isNotNull(); @@ -551,7 +630,14 @@ public void parseBootstrap_missingPluginName() { @Test public void parseBootstrap_grpcServerResourceId() throws XdsInitializationException { String rawData = "{\n" - + " \"xds_servers\": [],\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ]\n" + + " }\n" + + " ],\n" + " \"server_listener_resource_name_template\": \"grpc/serverx=%s\"\n" + "}"; @@ -627,6 +713,28 @@ public void serverFeatureIgnoreResourceDeletion() throws XdsInitializationExcept assertThat(serverInfo.ignoreResourceDeletion()).isTrue(); } + @Test + public void serverFeatureTrustedXdsServer() throws XdsInitializationException { + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ],\n" + + " \"server_features\": [\"trusted_xds_server\"]\n" + + " }\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); + assertThat(serverInfo.target()).isEqualTo(SERVER_URI); + assertThat(serverInfo.implSpecificConfig()).isInstanceOf(InsecureChannelCredentials.class); + assertThat(serverInfo.isTrustedXdsServer()).isTrue(); + } + @Test public void serverFeatureIgnoreResourceDeletion_xdsV3() throws XdsInitializationException { String rawData = "{\n" @@ -650,6 +758,72 @@ public void serverFeatureIgnoreResourceDeletion_xdsV3() throws XdsInitialization 
assertThat(serverInfo.ignoreResourceDeletion()).isTrue(); } + @Test + public void serverFeatures_ignoresUnknownValues() throws XdsInitializationException { + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ],\n" + + " \"server_features\": [null, {}, 3, true, \"unexpected\", \"trusted_xds_server\"]\n" + + " }\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); + assertThat(serverInfo.isTrustedXdsServer()).isTrue(); + } + + @Test + public void serverFeature_failOnDataErrors() throws XdsInitializationException { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ],\n" + + " \"server_features\": [\"fail_on_data_errors\"]\n" + + " }\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); + assertThat(serverInfo.target()).isEqualTo(SERVER_URI); + assertThat(serverInfo.implSpecificConfig()).isInstanceOf(InsecureChannelCredentials.class); + assertThat(serverInfo.failOnDataErrors()).isTrue(); + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + @Test + public void serverFeature_failOnDataErrors_requiresEnvVar() throws XdsInitializationException { + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ],\n" + + " \"server_features\": 
[\"fail_on_data_errors\"]\n" + + " }\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); + // Should be false when env var is not enabled + assertThat(serverInfo.failOnDataErrors()).isFalse(); + } + @Test public void notFound() { bootstrapper.bootstrapPathFromEnvVar = null; @@ -732,6 +906,12 @@ public void fallbackToConfigFromSysProp() throws XdsInitializationException { public void parseClientDefaultListenerResourceNameTemplate() throws Exception { String rawData = "{\n" + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ]\n" + + " }\n" + " ]\n" + "}"; bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); @@ -741,6 +921,12 @@ public void parseClientDefaultListenerResourceNameTemplate() throws Exception { rawData = "{\n" + " \"client_default_listener_resource_name_template\": \"xdstp://a.com/faketype/%s\",\n" + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ]\n" + + " }\n" + " ]\n" + "}"; bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); @@ -824,7 +1010,7 @@ public void parseAuthorities() throws Exception { } @Test - public void badFederationConfig() throws Exception { + public void badFederationConfig() { String rawData = "{\n" + " \"authorities\": {\n" + " \"a.com\": {\n" diff --git a/xds/src/test/java/io/grpc/xds/GrpcServiceConfigParserTest.java b/xds/src/test/java/io/grpc/xds/GrpcServiceConfigParserTest.java new file mode 100644 index 00000000000..ddfd0f19498 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/GrpcServiceConfigParserTest.java @@ -0,0 +1,757 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.Duration; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.insecure.v3.InsecureCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.local.v3.LocalCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.xds.v3.XdsCredentials; +import io.grpc.Attributes; +import io.grpc.CallCredentials; +import io.grpc.CompositeCallCredentials; +import io.grpc.CompositeChannelCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.SecurityLevel; +import io.grpc.Status; +import io.grpc.alts.GoogleDefaultChannelCredentials; +import io.grpc.xds.client.AllowedGrpcServices; +import io.grpc.xds.client.AllowedGrpcServices.AllowedGrpcService; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import 
io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.ConfiguredChannelCredentials; +import io.grpc.xds.client.EnvoyProtoData.Node; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig; +import io.grpc.xds.internal.grpcservice.GrpcServiceParseException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class GrpcServiceConfigParserTest { + + private static final String CALL_CREDENTIALS_CLASS_NAME = + "io.grpc.xds.GrpcServiceConfigParser$SecurityAwareAccessTokenCredentials"; + + private static BootstrapInfo dummyBootstrapInfo() { + return dummyBootstrapInfo(Optional.empty()); + } + + private static BootstrapInfo dummyBootstrapInfo(Optional implSpecificObject) { + return BootstrapInfo.builder() + .servers(Collections + .singletonList(ServerInfo.create("test_target", Collections.emptyMap()))) + .node(Node.newBuilder().build()).implSpecificObject(implSpecificObject).build(); + } + + private static ServerInfo dummyServerInfo() { + return dummyServerInfo(true); + } + + private static ServerInfo dummyServerInfo(boolean isTrusted) { + return ServerInfo.create("test_target", Collections.emptyMap(), false, isTrusted, false, + false); + } + + private static GrpcServiceConfig parse( + GrpcService grpcServiceProto, BootstrapInfo bootstrapInfo, + ServerInfo serverInfo) + throws GrpcServiceParseException { + return GrpcServiceConfigParser.parse(grpcServiceProto, bootstrapInfo, serverInfo); + } + + private static GrpcServiceConfig.GoogleGrpcConfig parseGoogleGrpcConfig( + GrpcService.GoogleGrpc googleGrpcProto, BootstrapInfo bootstrapInfo, + ServerInfo serverInfo) + throws GrpcServiceParseException { + return 
GrpcServiceConfigParser.parseGoogleGrpcConfig( + googleGrpcProto, bootstrapInfo, serverInfo); + } + + @Test + public void parse_success() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + HeaderValue asciiHeader = + HeaderValue.newBuilder().setKey("test_key").setValue("test_value").build(); + HeaderValue binaryHeader = + HeaderValue.newBuilder().setKey("test_key-bin").setRawValue(ByteString + .copyFrom("test_value_binary".getBytes(StandardCharsets.UTF_8))).build(); + Duration timeout = Duration.newBuilder().setSeconds(10).build(); + GrpcService grpcService = + GrpcService.newBuilder().setGoogleGrpc(googleGrpc).addInitialMetadata(asciiHeader) + .addInitialMetadata(binaryHeader).setTimeout(timeout).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + // Assert target URI + assertThat(config.googleGrpc().target()).isEqualTo("test_uri"); + + // Assert channel credentials + assertThat(config.googleGrpc().configuredChannelCredentials().channelCredentials()) + .isInstanceOf(InsecureChannelCredentials.class); + GrpcServiceConfigParser.ProtoChannelCredsConfig credsConfig = + (GrpcServiceConfigParser.ProtoChannelCredsConfig) + config.googleGrpc().configuredChannelCredentials().channelCredsConfig(); + assertThat(credsConfig.configProto()).isEqualTo(insecureCreds); + + // Assert call credentials + assertThat(config.googleGrpc().callCredentials().isPresent()).isTrue(); + assertThat(config.googleGrpc().callCredentials().get().getClass().getName()) + .isEqualTo(CALL_CREDENTIALS_CLASS_NAME); + + // Assert initial metadata + 
assertThat(config.initialMetadata()).isNotEmpty(); + assertThat(config.initialMetadata().get(0).key()).isEqualTo("test_key"); + assertThat(config.initialMetadata().get(0).value().get()).isEqualTo("test_value"); + assertThat(config.initialMetadata().get(1).key()).isEqualTo("test_key-bin"); + assertThat(config.initialMetadata().get(1).rawValue().get().toByteArray()) + .isEqualTo("test_value_binary".getBytes(StandardCharsets.UTF_8)); + + // Assert timeout + assertThat(config.timeout().isPresent()).isTrue(); + assertThat(config.timeout().get()).isEqualTo(java.time.Duration.ofSeconds(10)); + } + + @Test + public void parse_minimalSuccess_defaults() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.googleGrpc().target()).isEqualTo("test_uri"); + assertThat(config.initialMetadata()).isEmpty(); + assertThat(config.timeout().isPresent()).isFalse(); + } + + @Test + public void parse_missingGoogleGrpc() { + GrpcService grpcService = GrpcService.newBuilder().build(); + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .startsWith("Unsupported: GrpcService must have GoogleGrpc, got: "); + } + + @Test + public void parse_emptyCallCredentials() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + 
GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.googleGrpc().callCredentials().isPresent()).isFalse(); + } + + @Test + public void parse_emptyChannelCredentials() { + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addCallCredentialsPlugin(accessTokenCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .isEqualTo("No valid supported channel_credentials found"); + } + + @Test + public void parse_googleDefaultCredentials() throws GrpcServiceParseException { + Any googleDefaultCreds = Any.pack(GoogleDefaultCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(googleDefaultCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.googleGrpc().configuredChannelCredentials().channelCredentials()) + .isInstanceOf(CompositeChannelCredentials.class); + GrpcServiceConfigParser.ProtoChannelCredsConfig credsConfig = + 
(GrpcServiceConfigParser.ProtoChannelCredsConfig) + config.googleGrpc().configuredChannelCredentials().channelCredsConfig(); + assertThat(credsConfig.configProto()).isEqualTo(googleDefaultCreds); + } + + @Test + public void parse_localCredentials() throws GrpcServiceParseException { + Any localCreds = Any.pack(LocalCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(localCreds).addCallCredentialsPlugin(accessTokenCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("LocalCredentials are not supported in grpc-java"); + } + + @Test + public void parse_xdsCredentials_withInsecureFallback() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + XdsCredentials xdsCreds = + XdsCredentials.newBuilder().setFallbackCredentials(insecureCreds).build(); + Any xdsCredsAny = Any.pack(xdsCreds); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(xdsCredsAny).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = GrpcServiceConfigParser.parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.googleGrpc().configuredChannelCredentials().channelCredentials()) + .isNotNull(); + GrpcServiceConfigParser.ProtoChannelCredsConfig 
credsConfig = + (GrpcServiceConfigParser.ProtoChannelCredsConfig) + config.googleGrpc().configuredChannelCredentials().channelCredsConfig(); + assertThat(credsConfig.configProto()).isEqualTo(xdsCredsAny); + } + + @Test + public void parse_tlsCredentials_notSupported() { + Any tlsCreds = Any + .pack(io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.tls.v3.TlsCredentials + .getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(tlsCreds).addCallCredentialsPlugin(accessTokenCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("TlsCredentials input stream construction pending"); + } + + @Test + public void parse_invalidChannelCredentialsProto() { + // Pack a Duration proto, but try to unpack it as GoogleDefaultCredentials + Any invalidCreds = Any.pack(Duration.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(invalidCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat().contains("No valid supported channel_credentials found"); + } + + @Test + public void 
parse_ignoredUnsupportedCallCredentialsProto() throws GrpcServiceParseException { + // Pack a Duration proto, but try to unpack it as AccessTokenCredentials + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any invalidCallCredentials = Any.pack(Duration.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(invalidCallCredentials) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + assertThat(config.googleGrpc().callCredentials().isPresent()).isFalse(); + } + + @Test + public void parse_invalidAccessTokenCallCredentialsProto() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any invalidCallCredentials = Any.pack(AccessTokenCredentials.newBuilder().setToken("").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(invalidCallCredentials) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("Missing or empty access token in call credentials"); + } + + @Test + public void parse_multipleCallCredentials() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds1 = + Any.pack(AccessTokenCredentials.newBuilder().setToken("token1").build()); + Any accessTokenCreds2 = + Any.pack(AccessTokenCredentials.newBuilder().setToken("token2").build()); + GrpcService.GoogleGrpc googleGrpc = 
GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(accessTokenCreds1) + .addCallCredentialsPlugin(accessTokenCreds2).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.googleGrpc().callCredentials().isPresent()).isTrue(); + assertThat(config.googleGrpc().callCredentials().get()) + .isInstanceOf(CompositeCallCredentials.class); + } + + @Test + public void parse_untrustedControlPlane_withoutOverride() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + BootstrapInfo untrustedBootstrapInfo = dummyBootstrapInfo(Optional.empty()); + ServerInfo untrustedServerInfo = + dummyServerInfo(false); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse( + grpcService, untrustedBootstrapInfo, untrustedServerInfo)); + assertThat(exception).hasMessageThat() + .contains("Untrusted xDS server & URI not found in allowed_grpc_services"); + } + + @Test + public void parse_untrustedControlPlane_withOverride() throws GrpcServiceParseException { + // The proto credentials (insecure) should be ignored in favor of the override (google default) + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + ConfiguredChannelCredentials overrideChannelCreds = 
ConfiguredChannelCredentials.create( + GoogleDefaultChannelCredentials.create(), + new GrpcServiceConfigParser.ProtoChannelCredsConfig( + GrpcServiceConfigParser.GOOGLE_DEFAULT_CREDENTIALS_TYPE_URL, + Any.pack(GoogleDefaultCredentials.getDefaultInstance()))); + AllowedGrpcService override = AllowedGrpcService.builder() + .configuredChannelCredentials(overrideChannelCreds).build(); + AllowedGrpcServices servicesMap = + AllowedGrpcServices.create( + ImmutableMap.of("test_uri", override)); + + BootstrapInfo untrustedBootstrapInfo = + dummyBootstrapInfo(Optional.of(GrpcBootstrapImplConfig.create(servicesMap))); + ServerInfo untrustedServerInfo = + dummyServerInfo(false); + + GrpcServiceConfig config = + parse(grpcService, untrustedBootstrapInfo, untrustedServerInfo); + + // Assert channel credentials are the override, not the proto's insecure creds + assertThat(config.googleGrpc().configuredChannelCredentials().channelCredentials()) + .isInstanceOf(CompositeChannelCredentials.class); + } + + @Test + public void parse_invalidTimeout() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + + // Negative timeout + Duration timeout = Duration.newBuilder().setSeconds(-10).build(); + GrpcService grpcService = GrpcService.newBuilder() + .setGoogleGrpc(googleGrpc).setTimeout(timeout).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("Timeout must be strictly positive"); + + // Zero timeout + timeout = Duration.newBuilder().setSeconds(0).setNanos(0).build(); + GrpcService grpcServiceZero = GrpcService.newBuilder() + .setGoogleGrpc(googleGrpc).setTimeout(timeout).build(); + + exception = assertThrows(GrpcServiceParseException.class, + 
() -> parse(grpcServiceZero, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("Timeout must be strictly positive"); + } + + @Test + public void parseGoogleGrpcConfig_unsupportedScheme() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("unknown://test") + .addChannelCredentialsPlugin(insecureCreds).build(); + + BootstrapInfo bootstrapInfo = dummyBootstrapInfo(); + ServerInfo serverInfo = dummyServerInfo(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parseGoogleGrpcConfig( + googleGrpc, bootstrapInfo, serverInfo)); + assertThat(exception).hasMessageThat() + .contains("Target URI scheme is not resolvable"); + } + + @Test + public void parse_disallowedInitialMetadata() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + HeaderValue disallowedHeader = + HeaderValue.newBuilder().setKey("host").setValue("test_value").build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc) + .addInitialMetadata(disallowedHeader).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, dummyBootstrapInfo(), dummyServerInfo())); + assertThat(exception).hasMessageThat().contains("Invalid initial metadata header: host"); + } + + @Test + public void parse_invalidDuration() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + + Duration timeout = Duration.newBuilder().setSeconds(10).setNanos(1_000_000_000).build(); + GrpcService 
grpcService = GrpcService.newBuilder() + .setGoogleGrpc(googleGrpc).setTimeout(timeout).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, dummyBootstrapInfo(), dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("Timeout must be strictly positive and valid"); + } + + @Test + public void parse_invalidChannelCredsProto() { + Any invalidCreds = Any.newBuilder() + .setTypeUrl(GrpcServiceConfigParser.XDS_CREDENTIALS_TYPE_URL) + .setValue(ByteString.copyFrom(new byte[]{1, 2, 3})).build(); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(invalidCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, dummyBootstrapInfo(), dummyServerInfo())); + assertThat(exception).hasMessageThat().contains("Failed to parse channel credentials"); + } + + @Test + public void parse_unsupportedXdsFallbackCreds() { + Any unsupportedFallback = Any.pack(Duration.getDefaultInstance()); + XdsCredentials xds = + XdsCredentials.newBuilder().setFallbackCredentials(unsupportedFallback).build(); + Any xdsCredsAny = Any.newBuilder() + .setTypeUrl(GrpcServiceConfigParser.XDS_CREDENTIALS_TYPE_URL) + .setValue(xds.toByteString()).build(); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(xdsCredsAny).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, dummyBootstrapInfo(), dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("Unsupported fallback credentials type for XdsCredentials"); + } + + @Test + public void 
parse_invalidCallCredsProto() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + // We just create an Any representing AccessTokenCredentials but with invalid bytes + Any invalidCallCreds = Any.newBuilder() + .setTypeUrl(Any.pack(AccessTokenCredentials.getDefaultInstance()).getTypeUrl()) + .setValue(ByteString.copyFrom(new byte[]{1, 2, 3})).build(); + + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(invalidCallCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, dummyBootstrapInfo(), dummyServerInfo())); + assertThat(exception).hasMessageThat().contains("Failed to parse access token credentials"); + } + + @Test + public void parseGoogleGrpcConfig_malformedUriThrows() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri(":::::") + .addChannelCredentialsPlugin(insecureCreds).build(); + + BootstrapInfo bootstrapInfo = dummyBootstrapInfo(); + ServerInfo serverInfo = dummyServerInfo(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parseGoogleGrpcConfig(googleGrpc, bootstrapInfo, serverInfo)); + assertThat(exception).hasMessageThat().contains("Target URI scheme is not resolvable"); + } + + @Test + public void parseGoogleGrpcConfig_untrustedWithCallCredentialsOverride() throws Exception { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + + ConfiguredChannelCredentials overrideChannelCreds = + 
ConfiguredChannelCredentials.create(GoogleDefaultChannelCredentials.create(), + new GrpcServiceConfigParser.ProtoChannelCredsConfig( + GrpcServiceConfigParser.GOOGLE_DEFAULT_CREDENTIALS_TYPE_URL, + Any.pack(GoogleDefaultCredentials.getDefaultInstance()))); + + CallCredentials fakeCallCreds = Mockito.mock(CallCredentials.class); + AllowedGrpcService override = AllowedGrpcService.builder() + .configuredChannelCredentials(overrideChannelCreds).callCredentials(fakeCallCreds).build(); + + AllowedGrpcServices servicesMap = + AllowedGrpcServices + .create(ImmutableMap.of("test_uri", override)); + + BootstrapInfo untrustedBootstrapInfo = + dummyBootstrapInfo(Optional.of(GrpcBootstrapImplConfig.create(servicesMap))); + ServerInfo untrustedServerInfo = dummyServerInfo(false); + + GrpcServiceConfig.GoogleGrpcConfig config = + parseGoogleGrpcConfig(googleGrpc, untrustedBootstrapInfo, untrustedServerInfo); + + assertThat(config.callCredentials().isPresent()).isTrue(); + assertThat(config.callCredentials().get()).isSameInstanceAs(fakeCallCreds); + } + + @Test + public void protoChannelCredsConfig_equalsAndHashCode() { + Any insecureCreds1 = Any.pack(InsecureCredentials.getDefaultInstance()); + Any insecureCreds2 = Any.pack(InsecureCredentials.getDefaultInstance()); + Any localCreds = Any.pack(LocalCredentials.getDefaultInstance()); + + GrpcServiceConfigParser.ProtoChannelCredsConfig config1 = + new GrpcServiceConfigParser.ProtoChannelCredsConfig("type1", insecureCreds1); + GrpcServiceConfigParser.ProtoChannelCredsConfig config1Equivalent = + new GrpcServiceConfigParser.ProtoChannelCredsConfig("type1", insecureCreds2); + GrpcServiceConfigParser.ProtoChannelCredsConfig configDifferentType = + new GrpcServiceConfigParser.ProtoChannelCredsConfig("type2", insecureCreds1); + GrpcServiceConfigParser.ProtoChannelCredsConfig configDifferentProto = + new GrpcServiceConfigParser.ProtoChannelCredsConfig("type1", localCreds); + + assertThat(config1.type()).isEqualTo("type1"); + 
assertThat(config1.equals(config1)).isTrue(); + assertThat(config1.equals(null)).isFalse(); + assertThat(config1.equals(new Object())).isFalse(); + assertThat(config1.equals(config1Equivalent)).isTrue(); + assertThat(config1.hashCode()).isEqualTo(config1Equivalent.hashCode()); + assertThat(config1.equals(configDifferentType)).isFalse(); + assertThat(config1.equals(configDifferentProto)).isFalse(); + } + + static class RecordingMetadataApplier extends CallCredentials.MetadataApplier { + boolean applied = false; + boolean failed = false; + Metadata appliedHeaders = null; + + @Override + public void apply(Metadata headers) { + applied = true; + appliedHeaders = headers; + } + + @Override + public void fail(Status status) { + failed = true; + } + } + + static class FakeRequestInfo extends CallCredentials.RequestInfo { + private final SecurityLevel securityLevel; + private final MethodDescriptor methodDescriptor; + + FakeRequestInfo(SecurityLevel securityLevel) { + this.securityLevel = securityLevel; + this.methodDescriptor = MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("test_service/test_method") + .setRequestMarshaller(new NoopMarshaller()) + .setResponseMarshaller(new NoopMarshaller()) + .build(); + } + + private static class NoopMarshaller implements MethodDescriptor.Marshaller { + @Override + public InputStream stream(T value) { + return null; + } + + @Override + public T parse(InputStream stream) { + return null; + } + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return methodDescriptor; + } + + @Override + public SecurityLevel getSecurityLevel() { + return securityLevel; + } + + @Override + public String getAuthority() { + return "dummy-authority"; + } + + @Override + public Attributes getTransportAttrs() { + return Attributes.EMPTY; + } + } + + + @Test + public void securityAwareCredentials_secureConnection_appliesToken() throws Exception { + Any insecureCreds = 
Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds) + .addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + CallCredentials creds = config.googleGrpc().callCredentials().get(); + RecordingMetadataApplier applier = new RecordingMetadataApplier(); + CountDownLatch latch = new CountDownLatch(1); + + creds.applyRequestMetadata( + new FakeRequestInfo(SecurityLevel.PRIVACY_AND_INTEGRITY), + Runnable::run, // Use direct executor to avoid async issues in test + new CallCredentials.MetadataApplier() { + @Override + public void apply(Metadata headers) { + applier.apply(headers); + latch.countDown(); + } + + @Override + public void fail(Status status) { + applier.fail(status); + latch.countDown(); + } + }); + + latch.await(5, TimeUnit.SECONDS); + assertThat(applier.applied).isTrue(); + assertThat(applier.appliedHeaders.get( + Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("Bearer test_token"); + } + + @Test + public void securityAwareCredentials_insecureConnection_appliesEmptyMetadata() throws Exception { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds) + .addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = 
parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + CallCredentials creds = config.googleGrpc().callCredentials().get(); + RecordingMetadataApplier applier = new RecordingMetadataApplier(); + + creds.applyRequestMetadata( + new FakeRequestInfo(SecurityLevel.NONE), + Runnable::run, + applier); + + assertThat(applier.applied).isTrue(); + assertThat(applier.appliedHeaders.get( + Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER))) + .isNull(); + } + + +} diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java index 3c159ba7055..a1b1adae17f 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java @@ -19,9 +19,11 @@ import static com.google.common.truth.Truth.assertThat; import static io.envoyproxy.envoy.config.route.v3.RouteAction.ClusterSpecifierCase.CLUSTER_SPECIFIER_PLUGIN; import static io.grpc.xds.XdsEndpointResource.GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import com.github.udpa.udpa.type.v1.TypedStruct; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; @@ -50,6 +52,7 @@ import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.config.core.v3.HttpProtocolOptions; import io.envoyproxy.envoy.config.core.v3.Locality; +import io.envoyproxy.envoy.config.core.v3.Metadata; import io.envoyproxy.envoy.config.core.v3.PathConfigSource; import io.envoyproxy.envoy.config.core.v3.RuntimeFractionalPercent; import io.envoyproxy.envoy.config.core.v3.SelfConfigSource; @@ -84,6 +87,7 @@ import io.envoyproxy.envoy.extensions.filters.common.fault.v3.FaultDelay; import io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort; import 
io.envoyproxy.envoy.extensions.filters.http.fault.v3.HTTPFault; +import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.Audience; import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBACPerRoute; import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; @@ -91,10 +95,10 @@ import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; import io.envoyproxy.envoy.extensions.load_balancing_policies.client_side_weighted_round_robin.v3.ClientSideWeightedRoundRobin; import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; +import io.envoyproxy.envoy.extensions.transport_sockets.http_11_proxy.v3.Http11ProxyUpstreamTransport; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.SdsSecretConfig; @@ -108,10 +112,8 @@ import io.envoyproxy.envoy.type.v3.FractionalPercent; import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; import io.envoyproxy.envoy.type.v3.Int64Range; -import io.grpc.ClientInterceptor; import io.grpc.EquivalentAddressGroup; import io.grpc.InsecureChannelCredentials; -import io.grpc.LoadBalancer; import io.grpc.LoadBalancerRegistry; import io.grpc.Status.Code; import io.grpc.internal.JsonUtil; @@ -127,6 +129,8 @@ import io.grpc.xds.Endpoints.LbEndpoint; import 
io.grpc.xds.Endpoints.LocalityLbEndpoints; import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; +import io.grpc.xds.MetadataRegistry.MetadataValueParser; import io.grpc.xds.RouteLookupServiceClusterSpecifierPlugin.RlsPluginConfig; import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.VirtualHost.Route.RouteAction; @@ -136,11 +140,12 @@ import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.client.BackendMetricPropagation; import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.LoadStatsManager2; import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsResourceType; import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; -import io.grpc.xds.client.XdsResourceType.StructOrError; import io.grpc.xds.internal.Matchers; import io.grpc.xds.internal.Matchers.FractionMatcher; import io.grpc.xds.internal.Matchers.HeaderMatcher; @@ -149,14 +154,10 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -import javax.annotation.Nullable; import org.junit.After; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -164,27 +165,32 @@ @RunWith(JUnit4.class) public class GrpcXdsClientImplDataTest { + private static final FaultFilter.Provider FAULT_FILTER_PROVIDER = new FaultFilter.Provider(); + private static final RbacFilter.Provider RBAC_FILTER_PROVIDER = new RbacFilter.Provider(); + private static final RouterFilter.Provider ROUTER_FILTER_PROVIDER = new RouterFilter.Provider(); + private static final ServerInfo LRS_SERVER_INFO = 
ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); + private static final String GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE = + "GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE"; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); private final FilterRegistry filterRegistry = FilterRegistry.getDefaultRegistry(); private boolean originalEnableRouteLookup; private boolean originalEnableLeastRequest; + private boolean originalEnableUseSystemRootCerts; @Before public void setUp() { originalEnableRouteLookup = XdsRouteConfigureResource.enableRouteLookup; originalEnableLeastRequest = XdsClusterResource.enableLeastRequest; - assertThat(originalEnableLeastRequest).isFalse(); + originalEnableUseSystemRootCerts = XdsClusterResource.enableSystemRootCerts; } @After public void tearDown() { XdsRouteConfigureResource.enableRouteLookup = originalEnableRouteLookup; XdsClusterResource.enableLeastRequest = originalEnableLeastRequest; + XdsClusterResource.enableSystemRootCerts = originalEnableUseSystemRootCerts; } @Test @@ -200,7 +206,7 @@ public void parseRoute_withRouteAction() { .setCluster("cluster-foo")) .build(); StructOrError struct = XdsRouteConfigureResource.parseRoute( - proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of()); + proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()) .isEqualTo( @@ -208,7 +214,7 @@ public void parseRoute_withRouteAction() { RouteMatch.create(PathMatcher.fromPath("/service/method", false), Collections.emptyList(), null), RouteAction.forCluster( - "cluster-foo", Collections.emptyList(), null, null), + "cluster-foo", Collections.emptyList(), null, null, false), ImmutableMap.of())); } @@ -223,7 +229,7 @@ public void parseRoute_withNonForwardingAction() { 
.setNonForwardingAction(NonForwardingAction.getDefaultInstance()) .build(); StructOrError struct = XdsRouteConfigureResource.parseRoute( - proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of()); + proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct()) .isEqualTo( Route.forNonForwardingAction( @@ -242,7 +248,8 @@ public void parseRoute_withUnsupportedActionTypes() { .setRedirect(RedirectAction.getDefaultInstance()) .build(); res = XdsRouteConfigureResource.parseRoute( - redirectRoute, filterRegistry, ImmutableMap.of(), ImmutableSet.of()); + redirectRoute, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), + getXdsResourceTypeArgs(true)); assertThat(res.getStruct()).isNull(); assertThat(res.getErrorDetail()) .isEqualTo("Route [route-blade] with unknown action type: REDIRECT"); @@ -254,7 +261,8 @@ public void parseRoute_withUnsupportedActionTypes() { .setDirectResponse(DirectResponseAction.getDefaultInstance()) .build(); res = XdsRouteConfigureResource.parseRoute( - directResponseRoute, filterRegistry, ImmutableMap.of(), ImmutableSet.of()); + directResponseRoute, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), + getXdsResourceTypeArgs(true)); assertThat(res.getStruct()).isNull(); assertThat(res.getErrorDetail()) .isEqualTo("Route [route-blade] with unknown action type: DIRECT_RESPONSE"); @@ -266,7 +274,8 @@ public void parseRoute_withUnsupportedActionTypes() { .setFilterAction(FilterAction.getDefaultInstance()) .build(); res = XdsRouteConfigureResource.parseRoute( - filterRoute, filterRegistry, ImmutableMap.of(), ImmutableSet.of()); + filterRoute, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), + getXdsResourceTypeArgs(true)); assertThat(res.getStruct()).isNull(); assertThat(res.getErrorDetail()) .isEqualTo("Route [route-blade] with unknown action type: FILTER_ACTION"); @@ -288,7 +297,8 @@ public void parseRoute_skipRouteWithUnsupportedMatcher() { .setCluster("cluster-foo")) 
.build(); assertThat(XdsRouteConfigureResource.parseRoute( - proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of())) + proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), + getXdsResourceTypeArgs(true))) .isNull(); } @@ -305,7 +315,8 @@ public void parseRoute_skipRouteWithUnsupportedAction() { .setClusterHeader("cluster header")) // cluster_header action not supported .build(); assertThat(XdsRouteConfigureResource.parseRoute( - proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of())) + proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), + getXdsResourceTypeArgs(true))) .isNull(); } @@ -515,10 +526,48 @@ public void parseRouteAction_withCluster() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); + assertThat(struct.getErrorDetail()).isNull(); + assertThat(struct.getStruct().cluster()).isEqualTo("cluster-foo"); + assertThat(struct.getStruct().weightedClusters()).isNull(); + assertThat(struct.getStruct().autoHostRewrite()).isFalse(); + } + + @Test + public void parseRouteAction_withCluster_autoHostRewriteEnabled() { + System.setProperty(GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE, "true"); + try { + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setCluster("cluster-foo") + .setAutoHostRewrite(BoolValue.of(true)) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); + assertThat(struct.getErrorDetail()).isNull(); + assertThat(struct.getStruct().cluster()).isEqualTo("cluster-foo"); + assertThat(struct.getStruct().weightedClusters()).isNull(); + assertThat(struct.getStruct().autoHostRewrite()).isTrue(); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE); + } + } + + 
@Test + public void parseRouteAction_withCluster_flagDisabled_autoHostRewriteNotEnabled() { + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setCluster("cluster-foo") + .setAutoHostRewrite(BoolValue.of(true)) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct().cluster()).isEqualTo("cluster-foo"); assertThat(struct.getStruct().weightedClusters()).isNull(); + assertThat(struct.getStruct().autoHostRewrite()).isTrue(); } @Test @@ -539,12 +588,74 @@ public void parseRouteAction_withWeightedCluster() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); + assertThat(struct.getErrorDetail()).isNull(); + assertThat(struct.getStruct().cluster()).isNull(); + assertThat(struct.getStruct().weightedClusters()).containsExactly( + ClusterWeight.create("cluster-foo", 30, ImmutableMap.of()), + ClusterWeight.create("cluster-bar", 70, ImmutableMap.of())); + assertThat(struct.getStruct().autoHostRewrite()).isFalse(); + } + + @Test + public void parseRouteAction_withWeightedCluster_autoHostRewriteEnabled() { + System.setProperty(GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE, "true"); + try { + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setWeightedClusters( + WeightedCluster.newBuilder() + .addClusters( + WeightedCluster.ClusterWeight + .newBuilder() + .setName("cluster-foo") + .setWeight(UInt32Value.newBuilder().setValue(30))) + .addClusters(WeightedCluster.ClusterWeight + .newBuilder() + .setName("cluster-bar") + .setWeight(UInt32Value.newBuilder().setValue(70)))) + 
.setAutoHostRewrite(BoolValue.of(true)) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); + assertThat(struct.getErrorDetail()).isNull(); + assertThat(struct.getStruct().cluster()).isNull(); + assertThat(struct.getStruct().weightedClusters()).containsExactly( + ClusterWeight.create("cluster-foo", 30, ImmutableMap.of()), + ClusterWeight.create("cluster-bar", 70, ImmutableMap.of())); + assertThat(struct.getStruct().autoHostRewrite()).isTrue(); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE); + } + } + + @Test + public void parseRouteAction_withWeightedCluster_flagDisabled_autoHostRewriteDisabled() { + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setWeightedClusters( + WeightedCluster.newBuilder() + .addClusters( + WeightedCluster.ClusterWeight + .newBuilder() + .setName("cluster-foo") + .setWeight(UInt32Value.newBuilder().setValue(30))) + .addClusters(WeightedCluster.ClusterWeight + .newBuilder() + .setName("cluster-bar") + .setWeight(UInt32Value.newBuilder().setValue(70)))) + .setAutoHostRewrite(BoolValue.of(true)) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct().cluster()).isNull(); assertThat(struct.getStruct().weightedClusters()).containsExactly( ClusterWeight.create("cluster-foo", 30, ImmutableMap.of()), ClusterWeight.create("cluster-bar", 70, ImmutableMap.of())); + assertThat(struct.getStruct().autoHostRewrite()).isTrue(); } @Test @@ -565,7 +676,7 @@ public void parseRouteAction_weightedClusterSum() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), 
ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()).isEqualTo("Sum of cluster weights should be above 0."); } @@ -581,7 +692,7 @@ public void parseRouteAction_withTimeoutByGrpcTimeoutHeaderMax() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().timeoutNano()).isEqualTo(TimeUnit.SECONDS.toNanos(5L)); } @@ -596,7 +707,7 @@ public void parseRouteAction_withTimeoutByMaxStreamDuration() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().timeoutNano()).isEqualTo(TimeUnit.SECONDS.toNanos(5L)); } @@ -608,7 +719,7 @@ public void parseRouteAction_withTimeoutUnset() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().timeoutNano()).isNull(); } @@ -630,7 +741,7 @@ public void parseRouteAction_withRetryPolicy() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); RouteAction.RetryPolicy retryPolicy = struct.getStruct().retryPolicy(); assertThat(retryPolicy.maxAttempts()).isEqualTo(4); assertThat(retryPolicy.initialBackoff()).isEqualTo(Durations.fromMillis(500)); @@ -654,7 +765,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder.build()) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), 
ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().retryPolicy()).isNotNull(); assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()).isEmpty(); @@ -667,7 +778,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()).isEqualTo("No base_interval specified in retry_backoff"); // max_interval unset @@ -677,7 +788,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); retryPolicy = struct.getStruct().retryPolicy(); assertThat(retryPolicy.maxBackoff()).isEqualTo(Durations.fromMillis(500 * 10)); @@ -688,7 +799,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()) .isEqualTo("base_interval in retry_backoff must be positive"); @@ -701,7 +812,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()) .isEqualTo("max_interval in retry_backoff cannot be less than base_interval"); @@ -714,7 +825,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = 
XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()) .isEqualTo("max_interval in retry_backoff cannot be less than base_interval"); @@ -727,7 +838,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().retryPolicy().initialBackoff()) .isEqualTo(Durations.fromMillis(1)); assertThat(struct.getStruct().retryPolicy().maxBackoff()) @@ -743,7 +854,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); retryPolicy = struct.getStruct().retryPolicy(); assertThat(retryPolicy.initialBackoff()).isEqualTo(Durations.fromMillis(25)); assertThat(retryPolicy.maxBackoff()).isEqualTo(Durations.fromMillis(250)); @@ -762,7 +873,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) .containsExactly(Code.CANCELLED); @@ -780,7 +891,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) 
.containsExactly(Code.CANCELLED); @@ -798,7 +909,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) .containsExactly(Code.CANCELLED); } @@ -829,7 +940,7 @@ public void parseRouteAction_withHashPolicies() { io.envoyproxy.envoy.config.route.v3.RouteAction.HashPolicy.newBuilder() .setFilterState( FilterState.newBuilder() - .setKey(XdsResourceType.HASH_POLICY_FILTER_STATE_KEY))) + .setKey(XdsRouteConfigureResource.HASH_POLICY_FILTER_STATE_KEY))) .addHashPolicy( io.envoyproxy.envoy.config.route.v3.RouteAction.HashPolicy.newBuilder() .setQueryParameter( @@ -837,7 +948,7 @@ public void parseRouteAction_withHashPolicies() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); List policies = struct.getStruct().hashPolicies(); assertThat(policies).hasSize(2); assertThat(policies.get(0).type()).isEqualTo(HashPolicy.Type.HEADER); @@ -857,7 +968,7 @@ public void parseRouteAction_custerSpecifierNotSet() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct).isNull(); } @@ -870,10 +981,65 @@ public void parseRouteAction_clusterSpecifier_routeLookupDisabled() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct).isNull(); } + @Test + public void parseRouteAction_clusterSpecifier() { + 
XdsRouteConfigureResource.enableRouteLookup = true; + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setClusterSpecifierPlugin(CLUSTER_SPECIFIER_PLUGIN.name()) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(CLUSTER_SPECIFIER_PLUGIN.name(), RlsPluginConfig.create( + ImmutableMap.of("lookupService", "rls-cbt.googleapis.com"))), ImmutableSet.of(), + getXdsResourceTypeArgs(true)); + assertThat(struct.getStruct()).isNotNull(); + assertThat(struct.getStruct().autoHostRewrite()).isFalse(); + } + + @Test + public void parseRouteAction_clusterSpecifier_autoHostRewriteEnabled() { + System.setProperty(GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE, "true"); + try { + XdsRouteConfigureResource.enableRouteLookup = true; + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setClusterSpecifierPlugin(CLUSTER_SPECIFIER_PLUGIN.name()) + .setAutoHostRewrite(BoolValue.of(true)) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(CLUSTER_SPECIFIER_PLUGIN.name(), RlsPluginConfig.create( + ImmutableMap.of("lookupService", "rls-cbt.googleapis.com"))), ImmutableSet.of(), + getXdsResourceTypeArgs(true)); + assertThat(struct.getStruct()).isNotNull(); + assertThat(struct.getStruct().autoHostRewrite()).isTrue(); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE); + } + } + + @Test + public void parseRouteAction_clusterSpecifier_flagDisabled_autoHostRewriteDisabled() { + XdsRouteConfigureResource.enableRouteLookup = true; + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setClusterSpecifierPlugin(CLUSTER_SPECIFIER_PLUGIN.name()) + .setAutoHostRewrite(BoolValue.of(true)) + .build(); + StructOrError struct = + 
XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(CLUSTER_SPECIFIER_PLUGIN.name(), RlsPluginConfig.create( + ImmutableMap.of("lookupService", "rls-cbt.googleapis.com"))), ImmutableSet.of(), + getXdsResourceTypeArgs(true)); + assertThat(struct.getStruct()).isNotNull(); + assertThat(struct.getStruct().autoHostRewrite()).isTrue(); + } + @Test public void parseClusterWeight() { io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight proto = @@ -888,7 +1054,7 @@ public void parseClusterWeight() { } @Test - public void parseLocalityLbEndpoints_withHealthyEndpoints() { + public void parseLocalityLbEndpoints_withHealthyEndpoints() throws ResourceInvalidException { io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() .setLocality(Locality.newBuilder() @@ -908,11 +1074,38 @@ public void parseLocalityLbEndpoints_withHealthyEndpoints() { assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()).isEqualTo( LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create("172.14.14.5", 8888, 20, true)), 100, 1)); + Collections.singletonList(LbEndpoint.create("172.14.14.5", 8888, + 20, true, "", ImmutableMap.of())), + 100, 1, ImmutableMap.of())); + } + + @Test + public void parseLocalityLbEndpoints_onlyPermitIp() { + io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = + io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() + .setLocality(Locality.newBuilder() + .setRegion("region-foo").setZone("zone-foo").setSubZone("subZone-foo")) + .setLoadBalancingWeight(UInt32Value.newBuilder().setValue(100)) // locality weight + .setPriority(1) + .addLbEndpoints(io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress( + SocketAddress.newBuilder() + .setAddress("example.com").setPortValue(8888)))) + 
.setHealthStatus(io.envoyproxy.envoy.config.core.v3.HealthStatus.HEALTHY) + .setLoadBalancingWeight(UInt32Value.newBuilder().setValue(20))) // endpoint weight + .build(); + ResourceInvalidException ex = assertThrows( + ResourceInvalidException.class, + () -> XdsEndpointResource.parseLocalityLbEndpoints(proto)); + assertThat(ex.getMessage()).contains("IP"); + assertThat(ex.getMessage()).contains("example.com"); } @Test - public void parseLocalityLbEndpoints_treatUnknownHealthAsHealthy() { + public void parseLocalityLbEndpoints_treatUnknownHealthAsHealthy() + throws ResourceInvalidException { io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() .setLocality(Locality.newBuilder() @@ -932,11 +1125,13 @@ public void parseLocalityLbEndpoints_treatUnknownHealthAsHealthy() { assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()).isEqualTo( LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create("172.14.14.5", 8888, 20, true)), 100, 1)); + Collections.singletonList(LbEndpoint.create("172.14.14.5", 8888, + 20, true, "", ImmutableMap.of())), + 100, 1, ImmutableMap.of())); } @Test - public void parseLocalityLbEndpoints_withUnHealthyEndpoints() { + public void parseLocalityLbEndpoints_withUnHealthyEndpoints() throws ResourceInvalidException { io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() .setLocality(Locality.newBuilder() @@ -956,11 +1151,13 @@ public void parseLocalityLbEndpoints_withUnHealthyEndpoints() { assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()).isEqualTo( LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create("172.14.14.5", 8888, 20, false)), 100, 1)); + Collections.singletonList(LbEndpoint.create("172.14.14.5", 8888, 20, + false, "", ImmutableMap.of())), + 100, 1, ImmutableMap.of())); } @Test - public 
void parseLocalityLbEndpoints_ignorZeroWeightLocality() { + public void parseLocalityLbEndpoints_ignorZeroWeightLocality() throws ResourceInvalidException { io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() .setLocality(Locality.newBuilder() @@ -1017,7 +1214,10 @@ public void parseLocalityLbEndpoints_withDualStackEndpoints() { EquivalentAddressGroup expectedEag = new EquivalentAddressGroup(socketAddressList); assertThat(struct.getStruct()).isEqualTo( LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(expectedEag, 20, true)), 100, 1)); + Collections.singletonList(LbEndpoint.create( + expectedEag, 20, true, "", ImmutableMap.of())), 100, 1, ImmutableMap.of())); + } catch (ResourceInvalidException e) { + throw new RuntimeException(e); } finally { if (originalDualStackProp != null) { System.setProperty(GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS, originalDualStackProp); @@ -1028,7 +1228,7 @@ public void parseLocalityLbEndpoints_withDualStackEndpoints() { } @Test - public void parseLocalityLbEndpoints_invalidPriority() { + public void parseLocalityLbEndpoints_invalidPriority() throws ResourceInvalidException { io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() .setLocality(Locality.newBuilder() @@ -1074,37 +1274,39 @@ public String typeUrl() { } } - private static class TestFilter implements io.grpc.xds.Filter, - io.grpc.xds.Filter.ClientInterceptorBuilder { - @Override - public String[] typeUrls() { - return new String[]{"test-url"}; - } + private static class TestFilter implements io.grpc.xds.Filter { - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - return ConfigOrError.fromConfig(new SimpleFilterConfig(rawProtoMessage)); - } + static final class Provider implements io.grpc.xds.Filter.Provider { + @Override + public String[] typeUrls() { + 
return new String[]{"test-url"}; + } - @Override - public ConfigOrError parseFilterConfigOverride( - Message rawProtoMessage) { - return ConfigOrError.fromConfig(new SimpleFilterConfig(rawProtoMessage)); - } + @Override + public boolean isClientFilter() { + return true; + } - @Nullable - @Override - public ClientInterceptor buildClientInterceptor(FilterConfig config, - @Nullable FilterConfig overrideConfig, - LoadBalancer.PickSubchannelArgs args, - ScheduledExecutorService scheduler) { - return null; + @Override + public TestFilter newInstance(String name) { + return new TestFilter(); + } + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + return ConfigOrError.fromConfig(new SimpleFilterConfig(rawProtoMessage)); + } + + @Override + public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { + return ConfigOrError.fromConfig(new SimpleFilterConfig(rawProtoMessage)); + } } } @Test public void parseHttpFilter_typedStructMigration() { - filterRegistry.register(new TestFilter()); + filterRegistry.register(new TestFilter.Provider()); Struct rawStruct = Struct.newBuilder() .putFields("name", Value.newBuilder().setStringValue("default").build()) .build(); @@ -1133,7 +1335,7 @@ public void parseHttpFilter_typedStructMigration() { @Test public void parseOverrideHttpFilter_typedStructMigration() { - filterRegistry.register(new TestFilter()); + filterRegistry.register(new TestFilter.Provider()); Struct rawStruct0 = Struct.newBuilder() .putFields("name", Value.newBuilder().setStringValue("default0").build()) .build(); @@ -1174,7 +1376,7 @@ public void parseHttpFilter_unsupportedAndRequired() { @Test public void parseHttpFilter_routerFilterForClient() { - filterRegistry.register(RouterFilter.INSTANCE); + filterRegistry.register(ROUTER_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1188,7 +1390,7 @@ public void parseHttpFilter_routerFilterForClient() { @Test public void 
parseHttpFilter_routerFilterForServer() { - filterRegistry.register(RouterFilter.INSTANCE); + filterRegistry.register(ROUTER_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1202,7 +1404,7 @@ public void parseHttpFilter_routerFilterForServer() { @Test public void parseHttpFilter_faultConfigForClient() { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1229,7 +1431,7 @@ public void parseHttpFilter_faultConfigForClient() { @Test public void parseHttpFilter_faultConfigUnsupportedForServer() { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1258,7 +1460,7 @@ public void parseHttpFilter_faultConfigUnsupportedForServer() { @Test public void parseHttpFilter_rbacConfigForServer() { - filterRegistry.register(RbacFilter.INSTANCE); + filterRegistry.register(RBAC_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1285,7 +1487,7 @@ public void parseHttpFilter_rbacConfigForServer() { @Test public void parseHttpFilter_rbacConfigUnsupportedForClient() { - filterRegistry.register(RbacFilter.INSTANCE); + filterRegistry.register(RBAC_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1314,7 +1516,7 @@ public void parseHttpFilter_rbacConfigUnsupportedForClient() { @Test public void parseOverrideRbacFilterConfig() { - filterRegistry.register(RbacFilter.INSTANCE); + filterRegistry.register(RBAC_FILTER_PROVIDER); RBACPerRoute rbacPerRoute = RBACPerRoute.newBuilder() .setRbac( @@ -1340,7 +1542,7 @@ public void parseOverrideRbacFilterConfig() { @Test public void parseOverrideFilterConfigs_unsupportedButOptional() { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HTTPFault 
httpFault = HTTPFault.newBuilder() .setDelay(FaultDelay.newBuilder().setFixedDelay(Durations.fromNanos(3000))) .build(); @@ -1360,7 +1562,7 @@ public void parseOverrideFilterConfigs_unsupportedButOptional() { @Test public void parseOverrideFilterConfigs_unsupportedAndRequired() { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HTTPFault httpFault = HTTPFault.newBuilder() .setDelay(FaultDelay.newBuilder().setFixedDelay(Durations.fromNanos(3000))) .build(); @@ -1392,11 +1594,12 @@ public void parseHttpConnectionManager_xffNumTrustedHopsUnsupported() throws ResourceInvalidException { @SuppressWarnings("deprecation") HttpConnectionManager hcm = HttpConnectionManager.newBuilder().setXffNumTrustedHops(2).build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("HttpConnectionManager with xff_num_trusted_hops unsupported"); - XdsListenerResource.parseHttpConnectionManager( - hcm, filterRegistry, - true /* does not matter */); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, + () -> XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("HttpConnectionManager with xff_num_trusted_hops unsupported"); } @Test @@ -1406,12 +1609,13 @@ public void parseHttpConnectionManager_OriginalIpDetectionExtensionsMustEmpty() HttpConnectionManager hcm = HttpConnectionManager.newBuilder() .addOriginalIpDetectionExtensions(TypedExtensionConfig.newBuilder().build()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("HttpConnectionManager with original_ip_detection_extensions unsupported"); - XdsListenerResource.parseHttpConnectionManager( - hcm, filterRegistry, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, false, 
getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("HttpConnectionManager with original_ip_detection_extensions unsupported"); } - + @Test public void parseHttpConnectionManager_missingRdsAndInlinedRouteConfiguration() throws ResourceInvalidException { @@ -1424,11 +1628,12 @@ public void parseHttpConnectionManager_missingRdsAndInlinedRouteConfiguration() HttpFilter.newBuilder().setName("terminal").setTypedConfig( Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("HttpConnectionManager neither has inlined route_config nor RDS"); - XdsListenerResource.parseHttpConnectionManager( - hcm, filterRegistry, - true /* does not matter */); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("HttpConnectionManager neither has inlined route_config nor RDS"); } @Test @@ -1443,16 +1648,17 @@ public void parseHttpConnectionManager_duplicateHttpFilters() throws ResourceInv HttpFilter.newBuilder().setName("terminal").setTypedConfig( Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("HttpConnectionManager contains duplicate HttpFilter: envoy.filter.foo"); - XdsListenerResource.parseHttpConnectionManager( - hcm, filterRegistry, - true /* does not matter */); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("HttpConnectionManager contains duplicate HttpFilter: envoy.filter.foo"); } @Test public void parseHttpConnectionManager_lastNotTerminal() throws 
ResourceInvalidException { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HttpConnectionManager hcm = HttpConnectionManager.newBuilder() .addHttpFilters( @@ -1461,16 +1667,17 @@ public void parseHttpConnectionManager_lastNotTerminal() throws ResourceInvalidE HttpFilter.newBuilder().setName("envoy.filter.bar").setIsOptional(true) .setTypedConfig(Any.pack(HTTPFault.newBuilder().build()))) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("The last HttpFilter must be a terminal filter: envoy.filter.bar"); - XdsListenerResource.parseHttpConnectionManager( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( hcm, filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("The last HttpFilter must be a terminal filter: envoy.filter.bar"); } @Test public void parseHttpConnectionManager_terminalNotLast() throws ResourceInvalidException { - filterRegistry.register(RouterFilter.INSTANCE); + filterRegistry.register(ROUTER_FILTER_PROVIDER); HttpConnectionManager hcm = HttpConnectionManager.newBuilder() .addHttpFilters( @@ -1479,11 +1686,12 @@ public void parseHttpConnectionManager_terminalNotLast() throws ResourceInvalidE .addHttpFilters( HttpFilter.newBuilder().setName("envoy.filter.foo").setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("A terminal HttpFilter must be the last filter: terminal"); - XdsListenerResource.parseHttpConnectionManager( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( hcm, filterRegistry, - true); + true, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("A terminal HttpFilter must be the last filter: terminal"); } @Test @@ -1495,11 
+1703,12 @@ public void parseHttpConnectionManager_unknownFilters() throws ResourceInvalidEx .addHttpFilters( HttpFilter.newBuilder().setName("envoy.filter.bar").setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("The last HttpFilter must be a terminal filter: envoy.filter.bar"); - XdsListenerResource.parseHttpConnectionManager( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( hcm, filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("The last HttpFilter must be a terminal filter: envoy.filter.bar"); } @Test @@ -1507,11 +1716,12 @@ public void parseHttpConnectionManager_emptyFilters() throws ResourceInvalidExce HttpConnectionManager hcm = HttpConnectionManager.newBuilder() .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Missing HttpFilter in HttpConnectionManager."); - XdsListenerResource.parseHttpConnectionManager( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( hcm, filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("Missing HttpFilter in HttpConnectionManager."); } @Test @@ -1560,7 +1770,7 @@ public void parseHttpConnectionManager_clusterSpecifierPlugin() throws Exception io.grpc.xds.HttpConnectionManager parsedHcm = XdsListenerResource.parseHttpConnectionManager( hcm, filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true)); VirtualHost virtualHost = Iterables.getOnlyElement(parsedHcm.virtualHosts()); Route parsedRoute = Iterables.getOnlyElement(virtualHost.routes()); @@ -1635,12 +1845,12 @@ public void parseHttpConnectionManager_duplicatePluginName() throws 
Exception { Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Multiple ClusterSpecifierPlugins with the same name: rls-plugin-1"); - - XdsListenerResource.parseHttpConnectionManager( - hcm, filterRegistry, - true /* does not matter */); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("Multiple ClusterSpecifierPlugins with the same name: rls-plugin-1"); } @Test @@ -1687,12 +1897,12 @@ public void parseHttpConnectionManager_pluginNameNotFound() throws Exception { Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("ClusterSpecifierPlugin for [invalid-plugin-name] not found"); - - XdsListenerResource.parseHttpConnectionManager( - hcm, filterRegistry, - true /* does not matter */); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .contains("ClusterSpecifierPlugin for [invalid-plugin-name] not found"); } @@ -1766,7 +1976,7 @@ public void parseHttpConnectionManager_optionalPlugin() throws ResourceInvalidEx HttpFilter.newBuilder().setName("terminal").setTypedConfig( Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(), filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true)); // Verify that the only route left is the one with the registered RLS plugin `rls-plugin-1`, // while the route with unregistered optional `optional-plugin-`1 has been skipped. 
@@ -1794,7 +2004,7 @@ public void parseHttpConnectionManager_validateRdsConfigSource() throws Exceptio .build(); XdsListenerResource.parseHttpConnectionManager( hcm1, filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true)); HttpConnectionManager hcm2 = HttpConnectionManager.newBuilder() @@ -1808,7 +2018,7 @@ public void parseHttpConnectionManager_validateRdsConfigSource() throws Exceptio .build(); XdsListenerResource.parseHttpConnectionManager( hcm2, filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true)); HttpConnectionManager hcm3 = HttpConnectionManager.newBuilder() @@ -1821,12 +2031,12 @@ public void parseHttpConnectionManager_validateRdsConfigSource() throws Exceptio HttpFilter.newBuilder().setName("terminal").setTypedConfig( Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( + hcm3, filterRegistry, + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo( "HttpConnectionManager contains invalid RDS: must specify ADS or self ConfigSource"); - XdsListenerResource.parseHttpConnectionManager( - hcm3, filterRegistry, - true /* does not matter */); } @Test @@ -1916,11 +2126,10 @@ public void parseClusterSpecifierPlugin_unregisteredPlugin() throws Exception { .setTypedConfig(Any.pack(StringValue.of("unregistered")))) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsRouteConfigureResource.parseClusterSpecifierPlugin(pluginProto, registry)); + assertThat(e).hasMessageThat().isEqualTo( "Unsupported ClusterSpecifierPlugin type: type.googleapis.com/google.protobuf.StringValue"); - - 
XdsRouteConfigureResource.parseClusterSpecifierPlugin(pluginProto, registry); } @Test @@ -2117,11 +2326,11 @@ public void parseCluster_transportSocketMatches_exception() throws ResourceInval Cluster.TransportSocketMatch.newBuilder().setName("match1").build()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.processCluster(cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry())); + assertThat(e).hasMessageThat().isEqualTo( "Cluster cluster-foo.googleapis.com: transport-socket-matches not supported."); - XdsClusterResource.processCluster(cluster, null, LRS_SERVER_INFO, - LoadBalancerRegistry.getDefaultRegistry()); } @Test @@ -2166,12 +2375,303 @@ public void parseCluster_validateEdsSourceConfig() throws ResourceInvalidExcepti .setLbPolicy(LbPolicy.ROUND_ROBIN) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.processCluster(cluster3, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry())); + assertThat(e).hasMessageThat().isEqualTo( "Cluster cluster-foo.googleapis.com: field eds_cluster_config must be set to indicate to" + " use EDS over ADS or self ConfigSource"); - XdsClusterResource.processCluster(cluster3, null, LRS_SERVER_INFO, + } + + @Test + public void processCluster_parsesMetadata() + throws ResourceInvalidException, InvalidProtocolBufferException { + MetadataRegistry metadataRegistry = MetadataRegistry.getInstance(); + + MetadataValueParser testParser = + new MetadataValueParser() { + @Override + public String getTypeUrl() { + return "type.googleapis.com/test.Type"; + } + + @Override + public Object parse(Any value) { + assertThat(value.getValue().toStringUtf8()).isEqualTo("test"); + return value.getValue().toStringUtf8() + "_processed"; + } + }; + 
metadataRegistry.registerParser(testParser); + + Any typedFilterMetadata = Any.newBuilder() + .setTypeUrl("type.googleapis.com/test.Type") + .setValue(ByteString.copyFromUtf8("test")) + .build(); + + Struct filterMetadata = Struct.newBuilder() + .putFields("key1", Value.newBuilder().setStringValue("value1").build()) + .putFields("key2", Value.newBuilder().setNumberValue(42).build()) + .build(); + + Metadata metadata = Metadata.newBuilder() + .putTypedFilterMetadata("TYPED_FILTER_METADATA", typedFilterMetadata) + .putFilterMetadata("FILTER_METADATA", filterMetadata) + .build(); + + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .setMetadata(metadata) + .build(); + + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); + + ImmutableMap expectedParsedMetadata = ImmutableMap.of( + "TYPED_FILTER_METADATA", "test_processed", + "FILTER_METADATA", ImmutableMap.of( + "key1", "value1", + "key2", 42.0)); + assertThat(update.parsedMetadata()).isEqualTo(expectedParsedMetadata); + metadataRegistry.removeParser(testParser); + } + + @Test + public void processCluster_parsesAudienceMetadata() throws Exception { + MetadataRegistry.getInstance(); + + Audience audience = Audience.newBuilder() + .setUrl("https://example.com") + .build(); + + Any audienceMetadata = Any.newBuilder() + .setTypeUrl("type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.Audience") + .setValue(audience.toByteString()) + .build(); + + Struct filterMetadata = Struct.newBuilder() + .putFields("key1", Value.newBuilder().setStringValue("value1").build()) + .putFields("key2", Value.newBuilder().setNumberValue(42).build()) + 
.build(); + + Metadata metadata = Metadata.newBuilder() + .putTypedFilterMetadata("AUDIENCE_METADATA", audienceMetadata) + .putFilterMetadata("FILTER_METADATA", filterMetadata) + .build(); + + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .setMetadata(metadata) + .build(); + + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); + + ImmutableMap expectedParsedMetadata = ImmutableMap.of( + "AUDIENCE_METADATA", "https://example.com", + "FILTER_METADATA", ImmutableMap.of( + "key1", "value1", + "key2", 42.0)); + + assertThat(update.parsedMetadata().get("FILTER_METADATA")) + .isEqualTo(expectedParsedMetadata.get("FILTER_METADATA")); + assertThat(update.parsedMetadata().get("AUDIENCE_METADATA")) + .isInstanceOf(AudienceWrapper.class); + } + + @Test + public void processCluster_parsesAddressMetadata() throws Exception { + + // Create an Address message + Address address = Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("192.168.1.1") + .setPortValue(8080) + .build()) + .build(); + + // Wrap the Address in Any + Any addressMetadata = Any.newBuilder() + .setTypeUrl("type.googleapis.com/envoy.config.core.v3.Address") + .setValue(address.toByteString()) + .build(); + + Struct filterMetadata = Struct.newBuilder() + .putFields("key1", Value.newBuilder().setStringValue("value1").build()) + .putFields("key2", Value.newBuilder().setNumberValue(42).build()) + .build(); + + Metadata metadata = Metadata.newBuilder() + .putTypedFilterMetadata("ADDRESS_METADATA", addressMetadata) + .putFilterMetadata("FILTER_METADATA", filterMetadata) + .build(); + + Cluster cluster = 
Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .setMetadata(metadata) + .build(); + + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, LoadBalancerRegistry.getDefaultRegistry()); + + ImmutableMap expectedParsedMetadata = ImmutableMap.of( + "ADDRESS_METADATA", new InetSocketAddress("192.168.1.1", 8080), + "FILTER_METADATA", ImmutableMap.of( + "key1", "value1", + "key2", 42.0)); + + assertThat(update.parsedMetadata()).isEqualTo(expectedParsedMetadata); + } + + @Test + public void processCluster_metadataKeyCollision_resolvesToTypedMetadata() throws Exception { + MetadataRegistry metadataRegistry = MetadataRegistry.getInstance(); + + MetadataValueParser testParser = + new MetadataValueParser() { + @Override + public String getTypeUrl() { + return "type.googleapis.com/test.Type"; + } + + @Override + public Object parse(Any value) { + return "typedMetadataValue"; + } + }; + metadataRegistry.registerParser(testParser); + + Any typedFilterMetadata = Any.newBuilder() + .setTypeUrl("type.googleapis.com/test.Type") + .setValue(ByteString.copyFromUtf8("test")) + .build(); + + Struct filterMetadata = Struct.newBuilder() + .putFields("key1", Value.newBuilder().setStringValue("filterMetadataValue").build()) + .build(); + + Metadata metadata = Metadata.newBuilder() + .putTypedFilterMetadata("key1", typedFilterMetadata) + .putFilterMetadata("key1", filterMetadata) + .build(); + + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + 
.setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .setMetadata(metadata) + .build(); + + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); + + ImmutableMap expectedParsedMetadata = ImmutableMap.of( + "key1", "typedMetadataValue"); + assertThat(update.parsedMetadata()).isEqualTo(expectedParsedMetadata); + metadataRegistry.removeParser(testParser); + } + + @Test + public void parseNonAggregateCluster_withHttp11ProxyTransportSocket() throws Exception { + XdsClusterResource.isEnabledXdsHttpConnect = true; + + Http11ProxyUpstreamTransport http11ProxyUpstreamTransport = + Http11ProxyUpstreamTransport.newBuilder() + .setTransportSocket(TransportSocket.getDefaultInstance()) + .build(); + + TransportSocket transportSocket = TransportSocket.newBuilder() + .setName("envoy.transport_sockets.http_11_proxy") + .setTypedConfig(Any.pack(http11ProxyUpstreamTransport)) + .build(); + + Cluster cluster = Cluster.newBuilder() + .setName("cluster-http11-proxy.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder().setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-http11-proxy.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .setTransportSocket(transportSocket) + .build(); + + CdsUpdate result = + XdsClusterResource.processCluster(cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); + + assertThat(result).isNotNull(); + assertThat(result.isHttp11ProxyAvailable()).isTrue(); + } + + @Test + public void processCluster_parsesOrcaLrsPropagationMetrics() throws ResourceInvalidException { + LoadStatsManager2.isEnabledOrcaLrsPropagation = true; + + ImmutableList metricSpecs = ImmutableList.of( + "cpu_utilization", + "named_metrics.foo", + "unknown_metric_spec" + ); + Cluster cluster = Cluster.newBuilder() + 
.setName("cluster-orca.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder().setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-orca.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .addAllLrsReportEndpointMetrics(metricSpecs) + .build(); + + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, LoadBalancerRegistry.getDefaultRegistry()); + + BackendMetricPropagation propagationConfig = update.backendMetricPropagation(); + assertThat(propagationConfig).isNotNull(); + assertThat(propagationConfig.propagateCpuUtilization).isTrue(); + assertThat(propagationConfig.propagateMemUtilization).isFalse(); + assertThat(propagationConfig.shouldPropagateNamedMetric("foo")).isTrue(); + assertThat(propagationConfig.shouldPropagateNamedMetric("bar")).isFalse(); + assertThat(propagationConfig.shouldPropagateNamedMetric("unknown_metric_spec")) + .isFalse(); + + LoadStatsManager2.isEnabledOrcaLrsPropagation = false; } @Test @@ -2181,10 +2681,11 @@ public void parseServerSideListener_invalidTrafficDirection() throws ResourceInv .setName("listener1") .setTrafficDirection(TrafficDirection.OUTBOUND) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Listener listener1 with invalid traffic direction: OUTBOUND"); - XdsListenerResource.parseServerSideListener( - listener, null, filterRegistry, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("Listener listener1 with invalid traffic direction: OUTBOUND"); } @Test @@ -2194,7 +2695,7 @@ public void parseServerSideListener_noTrafficDirection() throws ResourceInvalidE .setName("listener1") .build(); XdsListenerResource.parseServerSideListener( - 
listener, null, filterRegistry, null); + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true)); } @Test @@ -2205,10 +2706,11 @@ public void parseServerSideListener_listenerFiltersPresent() throws ResourceInva .setTrafficDirection(TrafficDirection.INBOUND) .addListenerFilters(ListenerFilter.newBuilder().build()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Listener listener1 cannot have listener_filters"); - XdsListenerResource.parseServerSideListener( - listener, null, filterRegistry, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener(listener, null, filterRegistry, null, + getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("Listener listener1 cannot have listener_filters"); } @Test @@ -2219,10 +2721,44 @@ public void parseServerSideListener_useOriginalDst() throws ResourceInvalidExcep .setTrafficDirection(TrafficDirection.INBOUND) .setUseOriginalDst(BoolValue.of(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Listener listener1 cannot have use_original_dst set to true"); - XdsListenerResource.parseServerSideListener( - listener,null, filterRegistry, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener(listener, null, filterRegistry, null, + getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("Listener listener1 cannot have use_original_dst set to true"); + } + + @Test + public void parseServerSideListener_emptyAddress() throws ResourceInvalidException { + Listener listener = + Listener.newBuilder() + .setName("listener1") + .setTrafficDirection(TrafficDirection.INBOUND) + .setAddress(Address.newBuilder() + .setSocketAddress( + SocketAddress.newBuilder())) + .build(); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + 
XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo("Invalid address: Empty address is not allowed."); + } + + @Test + public void parseServerSideListener_namedPort() throws ResourceInvalidException { + Listener listener = + Listener.newBuilder() + .setName("listener1") + .setTrafficDirection(TrafficDirection.INBOUND) + .setAddress(Address.newBuilder() + .setSocketAddress( + SocketAddress.newBuilder() + .setAddress("172.14.14.5").setNamedPort(""))) + .build(); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo("NAMED_PORT is not supported in gRPC."); } @Test @@ -2268,10 +2804,11 @@ public void parseServerSideListener_nonUniqueFilterChainMatch() throws ResourceI .setTrafficDirection(TrafficDirection.INBOUND) .addAllFilterChains(Arrays.asList(filterChain1, filterChain2)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("FilterChainMatch must be unique. Found duplicate:"); - XdsListenerResource.parseServerSideListener( - listener, null, filterRegistry, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .startsWith("FilterChainMatch must be unique. Found duplicate:"); } @Test @@ -2317,10 +2854,11 @@ public void parseServerSideListener_nonUniqueFilterChainMatch_sameFilter() .setTrafficDirection(TrafficDirection.INBOUND) .addAllFilterChains(Arrays.asList(filterChain1, filterChain2)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("FilterChainMatch must be unique. 
Found duplicate:"); - XdsListenerResource.parseServerSideListener( - listener,null, filterRegistry, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .startsWith("FilterChainMatch must be unique. Found duplicate:"); } @Test @@ -2369,7 +2907,7 @@ public void parseServerSideListener_uniqueFilterChainMatch() throws ResourceInva .addAllFilterChains(Arrays.asList(filterChain1, filterChain2)) .build(); XdsListenerResource.parseServerSideListener( - listener, null, filterRegistry, null); + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true)); } @Test @@ -2380,11 +2918,12 @@ public void parseFilterChain_noHcm() throws ResourceInvalidException { .setFilterChainMatch(FilterChainMatch.getDefaultInstance()) .setTransportSocket(TransportSocket.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseFilterChain( + filterChain, "filter-chain-foo", null, filterRegistry, null, null, + getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo( "FilterChain filter-chain-foo should contain exact one HttpConnectionManager filter"); - XdsListenerResource.parseFilterChain( - filterChain, null, filterRegistry, null, null); } @Test @@ -2398,11 +2937,12 @@ public void parseFilterChain_duplicateFilter() throws ResourceInvalidException { .setTransportSocket(TransportSocket.getDefaultInstance()) .addAllFilters(Arrays.asList(filter, filter)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseFilterChain( + filterChain, "filter-chain-foo", null, filterRegistry, null, null, + 
getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo( "FilterChain filter-chain-foo should contain exact one HttpConnectionManager filter"); - XdsListenerResource.parseFilterChain( - filterChain, null, filterRegistry, null, null); } @Test @@ -2415,12 +2955,13 @@ public void parseFilterChain_filterMissingTypedConfig() throws ResourceInvalidEx .setTransportSocket(TransportSocket.getDefaultInstance()) .addFilters(filter) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseFilterChain( + filterChain, "filter-chain-foo", null, filterRegistry, null, null, + getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo( "FilterChain filter-chain-foo contains filter envoy.http_connection_manager " + "without typed_config"); - XdsListenerResource.parseFilterChain( - filterChain, null, filterRegistry, null, null); } @Test @@ -2437,17 +2978,18 @@ public void parseFilterChain_unsupportedFilter() throws ResourceInvalidException .setTransportSocket(TransportSocket.getDefaultInstance()) .addFilters(filter) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseFilterChain( + filterChain, "filter-chain-foo", null, filterRegistry, null, null, + getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo( "FilterChain filter-chain-foo contains filter unsupported with unsupported " + "typed_config type unsupported-type-url"); - XdsListenerResource.parseFilterChain( - filterChain, null, filterRegistry, null, null); } @Test public void parseFilterChain_noName() throws ResourceInvalidException { - FilterChain filterChain1 = + FilterChain filterChain0 = FilterChain.newBuilder() .setFilterChainMatch(FilterChainMatch.getDefaultInstance()) 
.addFilters(buildHttpConnectionManagerFilter( @@ -2457,9 +2999,53 @@ public void parseFilterChain_noName() throws ResourceInvalidException { .setTypedConfig(Any.pack(Router.newBuilder().build())) .build())) .build(); - FilterChain filterChain2 = + + FilterChain filterChain1 = + FilterChain.newBuilder() + .setFilterChainMatch( + FilterChainMatch.newBuilder().addAllSourcePorts(Arrays.asList(443, 8080))) + .addFilters(buildHttpConnectionManagerFilter( + HttpFilter.newBuilder() + .setName("http-filter-bar") + .setTypedConfig(Any.pack(Router.newBuilder().build())) + .setIsOptional(true) + .build())) + .build(); + + Listener listenerProto = + Listener.newBuilder() + .setName("listener1") + .setTrafficDirection(TrafficDirection.INBOUND) + .addAllFilterChains(Arrays.asList(filterChain0, filterChain1)) + .setDefaultFilterChain(filterChain0) + .build(); + EnvoyServerProtoData.Listener listener = XdsListenerResource.parseServerSideListener( + listenerProto, null, filterRegistry, null, getXdsResourceTypeArgs(true)); + + assertThat(listener.filterChains().get(0).name()).isEqualTo("chain_0"); + assertThat(listener.filterChains().get(1).name()).isEqualTo("chain_1"); + assertThat(listener.defaultFilterChain().name()).isEqualTo("chain_default"); + } + + @Test + public void parseFilterChain_duplicateName() throws ResourceInvalidException { + FilterChain filterChain0 = FilterChain.newBuilder() + .setName("filter_chain") .setFilterChainMatch(FilterChainMatch.getDefaultInstance()) + .addFilters(buildHttpConnectionManagerFilter( + HttpFilter.newBuilder() + .setName("http-filter-foo") + .setIsOptional(true) + .setTypedConfig(Any.pack(Router.newBuilder().build())) + .build())) + .build(); + + FilterChain filterChain1 = + FilterChain.newBuilder() + .setName("filter_chain") + .setFilterChainMatch( + FilterChainMatch.newBuilder().addAllSourcePorts(Arrays.asList(443, 8080))) .addFilters(buildHttpConnectionManagerFilter( HttpFilter.newBuilder() .setName("http-filter-bar") @@ -2468,204 
+3054,273 @@ public void parseFilterChain_noName() throws ResourceInvalidException { .build())) .build(); - EnvoyServerProtoData.FilterChain parsedFilterChain1 = XdsListenerResource.parseFilterChain( - filterChain1, null, filterRegistry, null, - null); - EnvoyServerProtoData.FilterChain parsedFilterChain2 = XdsListenerResource.parseFilterChain( - filterChain2, null, filterRegistry, null, - null); - assertThat(parsedFilterChain1.name()).isEqualTo(parsedFilterChain2.name()); + Listener listenerProto = + Listener.newBuilder() + .setName("listener1") + .setTrafficDirection(TrafficDirection.INBOUND) + .addAllFilterChains(Arrays.asList(filterChain0, filterChain1)) + .build(); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener( + listenerProto, null, filterRegistry, null, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("Filter chain names must be unique. Found duplicate: filter_chain"); } @Test - public void validateCommonTlsContext_tlsParams() throws ResourceInvalidException { + public void validateCommonTlsContext_tlsParams() { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setTlsParams(TlsParameters.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context with tls_params is not supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo("common-tls-context with tls_params is not supported"); } @Test - public void validateCommonTlsContext_customHandshaker() throws ResourceInvalidException { + public void validateCommonTlsContext_customHandshaker() { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() 
.setCustomHandshaker(TypedExtensionConfig.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context with custom_handshaker is not supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo( + "common-tls-context with custom_handshaker is not supported"); } @Test - public void validateCommonTlsContext_validationContext() throws ResourceInvalidException { + public void validateCommonTlsContext_validationContext() { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setValidationContext(CertificateValidationContext.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("ca_certificate_provider_instance is required in upstream-tls-context"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo( + "ca_certificate_provider_instance or system_root_certs is required " + + "in upstream-tls-context"); } @Test - public void validateCommonTlsContext_validationContextSdsSecretConfig() - throws ResourceInvalidException { + public void validateCommonTlsContext_validationContextSdsSecretConfig() { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setValidationContextSdsSecretConfig(SdsSecretConfig.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + 
assertThat(e).hasMessageThat().isEqualTo( "common-tls-context with validation_context_sds_secret_config is not supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test - @SuppressWarnings("deprecation") - public void validateCommonTlsContext_validationContextCertificateProvider() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setValidationContextCertificateProvider( - CommonTlsContext.CertificateProvider.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "common-tls-context with validation_context_certificate_provider is not supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); - } - - @Test - @SuppressWarnings("deprecation") - public void validateCommonTlsContext_validationContextCertificateProviderInstance() + public void validateCommonTlsContext_tlsCertificateProviderInstance_isRequiredForServer() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "common-tls-context with validation_context_certificate_provider_instance is not " - + "supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, true)); + assertThat(e).hasMessageThat().isEqualTo( + "tls_certificate_provider_instance is required in downstream-tls-context"); } @Test - public void validateCommonTlsContext_tlsCertificateProviderInstance_isRequiredForServer() + public void validateCommonTlsContext_tlsNewCertificateProviderInstance() throws ResourceInvalidException { 
CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName("name1")) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "tls_certificate_provider_instance is required in downstream-tls-context"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, true); + XdsClusterResource + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); } @Test @SuppressWarnings("deprecation") - public void validateCommonTlsContext_tlsNewCertificateProviderInstance() + public void validateCommonTlsContext_tlsDeprecatedCertificateProviderInstance() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setTlsCertificateProviderInstance( - CertificateProviderPluginInstance.newBuilder().setInstanceName("name1").build()) + .setTlsCertificateCertificateProviderInstance( + CommonTlsContext.CertificateProviderInstance.newBuilder().setInstanceName("name1")) .build(); XdsClusterResource .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_tlsCertificateProviderInstance() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setTlsCertificateCertificateProviderInstance( - CertificateProviderInstance.newBuilder().setInstanceName("name1").build()) + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName("name1")) .build(); XdsClusterResource .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_tlsCertificateProviderInstance_absentInBootstrapFile() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - 
.setTlsCertificateCertificateProviderInstance( - CertificateProviderInstance.newBuilder().setInstanceName("bad-name").build()) + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName("bad-name")) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, + ImmutableSet.of("name1", "name2"), true)); + assertThat(e).hasMessageThat().isEqualTo( "CertificateProvider instance name 'bad-name' not defined in the bootstrap file."); - XdsClusterResource - .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_validationContextProviderInstance() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CertificateProviderInstance.newBuilder().setInstanceName("name1").build()) - .build()) + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance(CertificateProviderPluginInstance.newBuilder() + .setInstanceName("name1")))) .build(); XdsClusterResource .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), false); } + @Test + public void + validateCommonTlsContext_combinedValidationContextSystemRootCerts_envVarNotSet_throws() { + XdsClusterResource.enableSystemRootCerts = false; + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setCombinedValidationContext( + CommonTlsContext.CombinedCertificateValidationContext.newBuilder() + .setDefaultValidationContext( + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + 
CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .build() + ) + .build()) + .build(); + try { + XdsClusterResource + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of(), false); + fail("Expected exception"); + } catch (ResourceInvalidException ex) { + assertThat(ex.getMessage()).isEqualTo( + "ca_certificate_provider_instance or system_root_certs is required in" + + " upstream-tls-context"); + } + } + + @Test + public void validateCommonTlsContext_combinedValidationContextSystemRootCerts() + throws ResourceInvalidException { + XdsClusterResource.enableSystemRootCerts = true; + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setCombinedValidationContext( + CommonTlsContext.CombinedCertificateValidationContext.newBuilder() + .setDefaultValidationContext( + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .build() + ) + .build()) + .build(); + XdsClusterResource + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of(), false); + } + @Test @SuppressWarnings("deprecation") - public void validateCommonTlsContext_validationContextProviderInstance_absentInBootstrapFile() - throws ResourceInvalidException { + public void validateCommonTlsContext_combinedValidationContextDeprecatedCertProvider() + throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName("cert1")) .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() .setValidationContextCertificateProviderInstance( - CertificateProviderInstance.newBuilder().setInstanceName("bad-name").build()) + CommonTlsContext.CertificateProviderInstance.newBuilder() + .setInstanceName("root1")) .build()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - 
"ca_certificate_provider_instance name 'bad-name' not defined in the bootstrap file."); XdsClusterResource - .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), false); + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("cert1", "root1"), true); } - @Test - public void validateCommonTlsContext_tlsCertificatesCount() throws ResourceInvalidException { + public void validateCommonTlsContext_validationContextSystemRootCerts_envVarNotSet_throws() { + XdsClusterResource.enableSystemRootCerts = false; CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .addTlsCertificates(TlsCertificate.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("tls_certificate_provider_instance is unset"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + .setValidationContext( + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .build()) + .build(); + try { + XdsClusterResource + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of(), false); + fail("Expected exception"); + } catch (ResourceInvalidException ex) { + assertThat(ex.getMessage()).isEqualTo( + "ca_certificate_provider_instance or system_root_certs is required in " + + "upstream-tls-context"); + } } @Test - public void validateCommonTlsContext_tlsCertificateSdsSecretConfigsCount() + public void validateCommonTlsContext_validationContextSystemRootCerts() throws ResourceInvalidException { + XdsClusterResource.enableSystemRootCerts = true; CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .addTlsCertificateSdsSecretConfigs(SdsSecretConfig.getDefaultInstance()) + .setValidationContext( + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .build()) .build(); - thrown.expect(ResourceInvalidException.class); - 
thrown.expectMessage( - "tls_certificate_provider_instance is unset"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of(), false); } @Test - @SuppressWarnings("deprecation") - public void validateCommonTlsContext_tlsCertificateCertificateProvider() + public void validateCommonTlsContext_validationContextProviderInstance_absentInBootstrapFile() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setCombinedValidationContext( + CommonTlsContext.CombinedCertificateValidationContext.newBuilder() + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance(CertificateProviderPluginInstance.newBuilder() + .setInstanceName("bad-name")))) + .build(); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, + ImmutableSet.of("name1", "name2"), false)); + assertThat(e).hasMessageThat().isEqualTo( + "ca_certificate_provider_instance name 'bad-name' not defined in the bootstrap file."); + } + + + @Test + public void validateCommonTlsContext_tlsCertificatesCount() throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .addTlsCertificates(TlsCertificate.getDefaultInstance()) + .build(); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo("tls_certificate_provider_instance is unset"); + } + + @Test + public void validateCommonTlsContext_tlsCertificateSdsSecretConfigsCount() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setTlsCertificateCertificateProvider( - CommonTlsContext.CertificateProvider.getDefaultInstance()) + 
.addTlsCertificateSdsSecretConfigs(SdsSecretConfig.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo( "tls_certificate_provider_instance is unset"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2673,9 +3328,11 @@ public void validateCommonTlsContext_combinedValidationContext_isRequiredForClie throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("ca_certificate_provider_instance is required in upstream-tls-context"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo( + "ca_certificate_provider_instance or system_root_certs is required " + + "in upstream-tls-context"); } @Test @@ -2685,10 +3342,11 @@ public void validateCommonTlsContext_combinedValidationContextWithoutCertProvide .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "ca_certificate_provider_instance is required in upstream-tls-context"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo( + "ca_certificate_provider_instance or system_root_certs is required in " + + 
"upstream-tls-context"); } @Test @@ -2698,174 +3356,169 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextForS CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CertificateProviderInstance.getDefaultInstance()) .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) .addMatchSubjectAltNames(StringMatcher.newBuilder().setExact("foo.com").build()) .build())) - .setTlsCertificateCertificateProviderInstance( - CertificateProviderInstance.getDefaultInstance()) + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("match_subject_alt_names only allowed in upstream_tls_context"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), true); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), true)); + assertThat(e).hasMessageThat().isEqualTo( + "match_subject_alt_names only allowed in upstream_tls_context"); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDefaultValContextVerifyCertSpki() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .setDefaultValidationContext( - CertificateValidationContext.newBuilder().addVerifyCertificateSpki("foo"))) - .setTlsCertificateCertificateProviderInstance( - 
CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) + .addVerifyCertificateSpki("foo"))) + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("verify_certificate_spki in default_validation_context is not " - + "supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false)); + assertThat(e).hasMessageThat().isEqualTo( + "verify_certificate_spki in default_validation_context is not supported"); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDefaultValContextVerifyCertHash() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .setDefaultValidationContext( - CertificateValidationContext.newBuilder().addVerifyCertificateHash("foo"))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) + .addVerifyCertificateHash("foo"))) + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - 
thrown.expectMessage("verify_certificate_hash in default_validation_context is not " - + "supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false)); + assertThat(e).hasMessageThat().isEqualTo( + "verify_certificate_hash in default_validation_context is not supported"); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextDfltValContextRequireSignedCertTimestamp() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) .setRequireSignedCertificateTimestamp(BoolValue.of(true)))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false)); + assertThat(e).hasMessageThat().isEqualTo( "require_signed_certificate_timestamp in default_validation_context is not " + "supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test - @SuppressWarnings("deprecation") public void 
validateCommonTlsContext_combinedValidationContextWithDefaultValidationContextCrl() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) .setCrl(DataSource.getDefaultInstance()))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("crl in default_validation_context is not supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false)); + assertThat(e).hasMessageThat().isEqualTo("crl in default_validation_context is not supported"); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDfltValContextCustomValidatorConfig() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) 
.setCustomValidatorConfig(TypedExtensionConfig.getDefaultInstance()))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("custom_validator_config in default_validation_context is not " - + "supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false)); + assertThat(e).hasMessageThat().isEqualTo( + "custom_validator_config in default_validation_context is not supported"); } @Test public void validateDownstreamTlsContext_noCommonTlsContext() throws ResourceInvalidException { DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.getDefaultInstance(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context is required in downstream-tls-context"); - XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, null)); + assertThat(e).hasMessageThat().isEqualTo( + "common-tls-context is required in downstream-tls-context"); } @Test - @SuppressWarnings("deprecation") public void validateDownstreamTlsContext_hasRequireSni() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .setTlsCertificateCertificateProviderInstance( - 
CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()))) + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.getDefaultInstance()) .build(); DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.newBuilder() .setCommonTlsContext(commonTlsContext) .setRequireSni(BoolValue.of(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("downstream-tls-context with require-sni is not supported"); - XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, ImmutableSet.of("")); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, + ImmutableSet.of(""))); + assertThat(e).hasMessageThat().isEqualTo( + "downstream-tls-context with require-sni is not supported"); } @Test - @SuppressWarnings("deprecation") public void validateDownstreamTlsContext_hasOcspStaplePolicy() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()))) + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.getDefaultInstance()) .build(); DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.newBuilder() .setCommonTlsContext(commonTlsContext) 
.setOcspStaplePolicy(DownstreamTlsContext.OcspStaplePolicy.STRICT_STAPLING) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, + ImmutableSet.of(""))); + assertThat(e).hasMessageThat().isEqualTo( "downstream-tls-context with ocsp_staple_policy value STRICT_STAPLING is not supported"); - XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, ImmutableSet.of("")); } @Test public void validateUpstreamTlsContext_noCommonTlsContext() throws ResourceInvalidException { UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext.getDefaultInstance(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context is required in upstream-tls-context"); - XdsClusterResource.validateUpstreamTlsContext(upstreamTlsContext, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateUpstreamTlsContext(upstreamTlsContext, null)); + assertThat(e).hasMessageThat().isEqualTo( + "common-tls-context is required in upstream-tls-context"); } @Test @@ -2917,7 +3570,7 @@ public void canonifyResourceName() { /** * Tests compliance with RFC 3986 section 3.3 - * https://datatracker.ietf.org/doc/html/rfc3986#section-3.3 + * https://datatracker.ietf.org/doc/html/rfc3986#section-3.3 . */ @Test public void percentEncodePath() { @@ -2957,4 +3610,10 @@ private static Filter buildHttpConnectionManagerFilter(HttpFilter... 
httpFilters "type.googleapis.com")) .build(); } + + private XdsResourceType.Args getXdsResourceTypeArgs(boolean isTrustedServer) { + return new XdsResourceType.Args( + ServerInfo.create("http://td", "", false, isTrustedServer, false, false), "1.0", null, null, null, null + ); + } } diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java index d41630cdb4a..af55e572811 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java @@ -18,14 +18,16 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; +import static io.grpc.StatusMatcher.statusHasCode; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -41,6 +43,7 @@ import com.google.protobuf.Duration; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; +import com.google.protobuf.StringValue; import com.google.protobuf.UInt32Value; import com.google.protobuf.util.Durations; import io.envoyproxy.envoy.config.cluster.v3.OutlierDetection; @@ -48,7 +51,6 @@ import io.envoyproxy.envoy.config.route.v3.WeightedCluster; import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; import 
io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.grpc.BindableService; import io.grpc.ChannelCredentials; import io.grpc.Context; @@ -58,6 +60,8 @@ import io.grpc.Server; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.StatusOr; +import io.grpc.StatusOrMatcher; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.BackoffPolicy; @@ -87,6 +91,7 @@ import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.BootstrapperImpl; import io.grpc.xds.client.EnvoyProtoData.Node; import io.grpc.xds.client.LoadStatsManager2.ClusterDropStats; import io.grpc.xds.client.Locality; @@ -95,7 +100,9 @@ import io.grpc.xds.client.XdsClient.ResourceMetadata.UpdateFailureState; import io.grpc.xds.client.XdsClient.ResourceUpdate; import io.grpc.xds.client.XdsClient.ResourceWatcher; +import io.grpc.xds.client.XdsClient.ServerConnectionCallback; import io.grpc.xds.client.XdsClientImpl; +import io.grpc.xds.client.XdsClientMetricReporter; import io.grpc.xds.client.XdsResourceType; import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; import io.grpc.xds.client.XdsTransportFactory; @@ -107,11 +114,13 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Queue; import java.util.concurrent.BlockingDeque; import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executor; +import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -124,7 +133,6 @@ import org.junit.runner.RunWith; import 
org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatchers; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -142,8 +150,9 @@ // The base class was used to test both xds v2 and v3. V2 is dropped now so the base class is not // necessary. Still keep it for future version usage. Remove if too much trouble to maintain. public abstract class GrpcXdsClientImplTestBase { + private static final String SERVER_URI = "trafficdirector.googleapis.com"; - private static final String SERVER_URI_CUSTOME_AUTHORITY = "trafficdirector2.googleapis.com"; + private static final String SERVER_URI_CUSTOM_AUTHORITY = "trafficdirector2.googleapis.com"; private static final String SERVER_URI_EMPTY_AUTHORITY = "trafficdirector3.googleapis.com"; private static final String LDS_RESOURCE = "listener.googleapis.com"; private static final String RDS_RESOURCE = "route-configuration.googleapis.com"; @@ -217,7 +226,7 @@ public boolean shouldAccept(Runnable command) { protected final Queue loadReportCalls = new ArrayDeque<>(); protected final AtomicBoolean adsEnded = new AtomicBoolean(true); protected final AtomicBoolean lrsEnded = new AtomicBoolean(true); - private final MessageFactory mf = createMessageFactory(); + protected MessageFactory mf; private static final long TIME_INCREMENT = TimeUnit.SECONDS.toNanos(1); /** Fake time provider increments time TIME_INCREMENT each call. */ @@ -231,47 +240,30 @@ public long currentTimeNanos() { private static final int VHOST_SIZE = 2; // LDS test resources. - private final Any testListenerVhosts = Any.pack(mf.buildListenerWithApiListener(LDS_RESOURCE, - mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); - private final Any testListenerRds = - Any.pack(mf.buildListenerWithApiListenerForRds(LDS_RESOURCE, RDS_RESOURCE)); + private Any testListenerVhosts; + private Any testListenerRds; // RDS test resources. 
- private final Any testRouteConfig = - Any.pack(mf.buildRouteConfiguration(RDS_RESOURCE, mf.buildOpaqueVirtualHosts(VHOST_SIZE))); + private Any testRouteConfig; // CDS test resources. - private final Any testClusterRoundRobin = - Any.pack(mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, - null, false, null, "envoy.transport_sockets.tls", null, null - )); + private Any testClusterRoundRobin; // EDS test resources. - private final Message lbEndpointHealthy = - mf.buildLocalityLbEndpoints("region1", "zone1", "subzone1", - mf.buildLbEndpoint("192.168.0.1", 8080, "healthy", 2), 1, 0); + private Message lbEndpointHealthy; // Locality with 0 endpoints - private final Message lbEndpointEmpty = - mf.buildLocalityLbEndpoints("region3", "zone3", "subzone3", - ImmutableList.of(), 2, 1); + private Message lbEndpointEmpty; // Locality with 0-weight endpoint - private final Message lbEndpointZeroWeight = - mf.buildLocalityLbEndpoints("region4", "zone4", "subzone4", - mf.buildLbEndpoint("192.168.142.5", 80, "unknown", 5), 0, 2); - private final Any testClusterLoadAssignment = Any.pack(mf.buildClusterLoadAssignment(EDS_RESOURCE, - ImmutableList.of(lbEndpointHealthy, lbEndpointEmpty, lbEndpointZeroWeight), - ImmutableList.of(mf.buildDropOverload("lb", 200), mf.buildDropOverload("throttle", 1000)))); - - @Captor - private ArgumentCaptor ldsUpdateCaptor; + private Message lbEndpointZeroWeight; + private Any testClusterLoadAssignment; @Captor - private ArgumentCaptor rdsUpdateCaptor; + private ArgumentCaptor> ldsUpdateCaptor; @Captor - private ArgumentCaptor cdsUpdateCaptor; + private ArgumentCaptor> rdsUpdateCaptor; @Captor - private ArgumentCaptor edsUpdateCaptor; + private ArgumentCaptor> cdsUpdateCaptor; @Captor - private ArgumentCaptor errorCaptor; + private ArgumentCaptor> edsUpdateCaptor; @Mock private BackoffPolicy.Provider backoffPolicyProvider; @@ -282,11 +274,19 @@ public long currentTimeNanos() { @Mock private ResourceWatcher ldsResourceWatcher; @Mock + private 
ResourceWatcher ldsResourceWatcher2; + @Mock private ResourceWatcher rdsResourceWatcher; @Mock private ResourceWatcher cdsResourceWatcher; @Mock private ResourceWatcher edsResourceWatcher; + @Mock + private ResourceWatcher stringResourceWatcher; + @Mock + private XdsClientMetricReporter xdsClientMetricReporter; + @Mock + private ServerConnectionCallback serverConnectionCallback; private ManagedChannel channel; private ManagedChannel channelForCustomAuthority; @@ -295,11 +295,61 @@ public long currentTimeNanos() { private boolean originalEnableLeastRequest; private Server xdsServer; private final String serverName = InProcessServerBuilder.generateName(); - private final BindableService adsService = createAdsService(); - private final BindableService lrsService = createLrsService(); + private BindableService adsService; + private BindableService lrsService; + + private XdsTransportFactory xdsTransportFactory = new XdsTransportFactory() { + @Override + public XdsTransport create(ServerInfo serverInfo) { + if (serverInfo.target().equals(SERVER_URI)) { + return new GrpcXdsTransport(channel); + } + if (serverInfo.target().equals(SERVER_URI_CUSTOM_AUTHORITY)) { + if (channelForCustomAuthority == null) { + channelForCustomAuthority = cleanupRule.register( + InProcessChannelBuilder.forName(serverName).directExecutor().build()); + } + return new GrpcXdsTransport(channelForCustomAuthority); + } + if (serverInfo.target().equals(SERVER_URI_EMPTY_AUTHORITY)) { + if (channelForEmptyAuthority == null) { + channelForEmptyAuthority = cleanupRule.register( + InProcessChannelBuilder.forName(serverName).directExecutor().build()); + } + return new GrpcXdsTransport(channelForEmptyAuthority); + } + throw new IllegalArgumentException("Can not create channel for " + serverInfo); + } + }; @Before public void setUp() throws IOException { + mf = createMessageFactory(); + testListenerVhosts = Any.pack(mf.buildListenerWithApiListener(LDS_RESOURCE, + mf.buildRouteConfiguration("do not care", 
mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); + testListenerRds = + Any.pack(mf.buildListenerWithApiListenerForRds(LDS_RESOURCE, RDS_RESOURCE)); + testRouteConfig = + Any.pack(mf.buildRouteConfiguration(RDS_RESOURCE, mf.buildOpaqueVirtualHosts(VHOST_SIZE))); + testClusterRoundRobin = + Any.pack(mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, + null, false, null, "envoy.transport_sockets.tls", null, null + )); + lbEndpointHealthy = + mf.buildLocalityLbEndpoints("region1", "zone1", "subzone1", + mf.buildLbEndpoint("192.168.0.1", 8080, "healthy", 2, "endpoint-host-name"), 1, 0); + lbEndpointEmpty = + mf.buildLocalityLbEndpoints("region3", "zone3", "subzone3", + ImmutableList.of(), 2, 1); + lbEndpointZeroWeight = + mf.buildLocalityLbEndpoints("region4", "zone4", "subzone4", + mf.buildLbEndpoint("192.168.142.5", 80, "unknown", 5, "endpoint-host-name"), 0, 2); + testClusterLoadAssignment = Any.pack(mf.buildClusterLoadAssignment(EDS_RESOURCE, + ImmutableList.of(lbEndpointHealthy, lbEndpointEmpty, lbEndpointZeroWeight), + ImmutableList.of(mf.buildDropOverload("lb", 200), mf.buildDropOverload("throttle", 1000)))); + adsService = createAdsService(); + lrsService = createLrsService(); + when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2); when(backoffPolicy1.nextBackoffNanos()).thenReturn(10L, 100L); when(backoffPolicy2.nextBackoffNanos()).thenReturn(20L, 200L); @@ -316,31 +366,9 @@ public void setUp() throws IOException { .start()); channel = cleanupRule.register(InProcessChannelBuilder.forName(serverName).directExecutor().build()); - XdsTransportFactory xdsTransportFactory = new XdsTransportFactory() { - @Override - public XdsTransport create(ServerInfo serverInfo) { - if (serverInfo.target().equals(SERVER_URI)) { - return new GrpcXdsTransport(channel); - } - if (serverInfo.target().equals(SERVER_URI_CUSTOME_AUTHORITY)) { - if (channelForCustomAuthority == null) { - channelForCustomAuthority = cleanupRule.register( - 
InProcessChannelBuilder.forName(serverName).directExecutor().build()); - } - return new GrpcXdsTransport(channelForCustomAuthority); - } - if (serverInfo.target().equals(SERVER_URI_EMPTY_AUTHORITY)) { - if (channelForEmptyAuthority == null) { - channelForEmptyAuthority = cleanupRule.register( - InProcessChannelBuilder.forName(serverName).directExecutor().build()); - } - return new GrpcXdsTransport(channelForEmptyAuthority); - } - throw new IllegalArgumentException("Can not create channel for " + serverInfo); - } - }; - xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, ignoreResourceDeletion()); + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, ignoreResourceDeletion(), + true, false, false); BootstrapInfo bootstrapInfo = Bootstrapper.BootstrapInfo.builder() .servers(Collections.singletonList(xdsServerInfo)) @@ -350,7 +378,7 @@ public XdsTransport create(ServerInfo serverInfo) { AuthorityInfo.create( "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/%s", ImmutableList.of(Bootstrapper.ServerInfo.create( - SERVER_URI_CUSTOME_AUTHORITY, CHANNEL_CREDENTIALS))), + SERVER_URI_CUSTOM_AUTHORITY, CHANNEL_CREDENTIALS))), "", AuthorityInfo.create( "xdstp:///envoy.config.listener.v3.Listener/%s", @@ -368,7 +396,8 @@ public XdsTransport create(ServerInfo serverInfo) { fakeClock.getStopwatchSupplier(), timeProvider, MessagePrinter.INSTANCE, - new TlsContextManagerImpl(bootstrapInfo)); + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); assertThat(resourceDiscoveryCalls).isEmpty(); assertThat(loadReportCalls).isEmpty(); @@ -475,10 +504,11 @@ private void verifyResourceMetadataAcked( private void verifyResourceMetadataNacked( XdsResourceType type, String resourceName, Any rawResource, String versionInfo, long updateTime, String failedVersion, long failedUpdateTimeNanos, - List failedDetails) { + List failedDetails, boolean cached) { ResourceMetadata resourceMetadata = verifyResourceMetadata(type, resourceName, 
rawResource, ResourceMetadataStatus.NACKED, versionInfo, updateTime, true); + assertThat(resourceMetadata.isCached()).isEqualTo(cached); UpdateFailureState errorState = resourceMetadata.getErrorState(); assertThat(errorState).isNotNull(); @@ -595,9 +625,104 @@ private void validateGoldenClusterLoadAssignment(EdsUpdate edsUpdate) { .containsExactly( Locality.create("region1", "zone1", "subzone1"), LocalityLbEndpoints.create( - ImmutableList.of(LbEndpoint.create("192.168.0.1", 8080, 2, true)), 1, 0), + ImmutableList.of(LbEndpoint.create("192.168.0.1", 8080, 2, true, + "endpoint-host-name", ImmutableMap.of())), 1, 0, ImmutableMap.of()), Locality.create("region3", "zone3", "subzone3"), - LocalityLbEndpoints.create(ImmutableList.of(), 2, 1)); + LocalityLbEndpoints.create(ImmutableList.of(), 2, 1, ImmutableMap.of())); + } + + /** + * Verifies that the {@link XdsClientMetricReporter#reportResourceUpdates} method has been called + * the expected number of times with the expected values for valid resource count, invalid + * resource count, and corresponding metric labels. + */ + private void verifyResourceValidInvalidCount(int times, long validResourceCount, + long invalidResourceCount, String xdsServerTargetLabel, + String resourceType) { + verify(xdsClientMetricReporter, times(times)).reportResourceUpdates( + eq(validResourceCount), + eq(invalidResourceCount), + eq(xdsServerTargetLabel), + eq(resourceType)); + } + + private void verifyServerFailureCount(int times, long serverFailureCount, String xdsServer) { + verify(xdsClientMetricReporter, times(times)).reportServerFailure( + eq(serverFailureCount), + eq(xdsServer)); + } + + /** + * Invokes the callback, which will be called by {@link XdsClientMetricReporter} to record + * whether XdsClient has a working ADS stream. 
+ */ + private void callback_ReportServerConnection() { + try { + Future unused = xdsClient.reportServerConnections(serverConnectionCallback); + } catch (Exception e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new AssertionError(e); + } + } + + private void verifyServerConnection(int times, boolean isConnected, String xdsServer) { + verify(serverConnectionCallback, times(times)).reportServerConnectionGauge( + eq(isConnected), + eq(xdsServer)); + } + + @Test + public void doParse_returnsSuccessfully() { + XdsStringResource resourceType = new XdsStringResource(); + xdsClient.watchXdsResource( + resourceType, "resource1", stringResourceWatcher, MoreExecutors.directExecutor()); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + + Any resource = Any.pack(StringValue.newBuilder().setValue("resource1").build()); + call.sendResponse(resourceType, resource, VERSION_1, "0000"); + verify(stringResourceWatcher).onResourceChanged(argThat(StatusOrMatcher.hasValue( + (StringUpdate arg) -> new StringUpdate("resource1").equals(arg)))); + } + + @Test + public void doParse_throwsResourceInvalidException_resourceInvalid() { + XdsStringResource resourceType = new XdsStringResource() { + @Override + protected StringUpdate doParse(Args args, Message unpackedMessage) + throws ResourceInvalidException { + throw new ResourceInvalidException("some bad input"); + } + }; + xdsClient.watchXdsResource( + resourceType, "resource1", stringResourceWatcher, MoreExecutors.directExecutor()); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + + Any resource = Any.pack(StringValue.newBuilder().setValue("resource1").build()); + call.sendResponse(resourceType, resource, VERSION_1, "0000"); + verify(stringResourceWatcher).onResourceChanged(argThat(StatusOrMatcher.hasStatus( + statusHasCode(Status.Code.UNAVAILABLE) + .andDescriptionContains("validation error: some bad input")))); + } + + @Test + public void 
doParse_throwsError_resourceInvalid() throws Exception { + XdsStringResource resourceType = new XdsStringResource() { + @Override + protected StringUpdate doParse(Args args, Message unpackedMessage) { + throw new AssertionError("something bad happened"); + } + }; + xdsClient.watchXdsResource( + resourceType, "resource1", stringResourceWatcher, MoreExecutors.directExecutor()); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + + Any resource = Any.pack(StringValue.newBuilder().setValue("resource1").build()); + call.sendResponse(resourceType, resource, VERSION_1, "0000"); + verify(stringResourceWatcher).onResourceChanged(argThat(StatusOrMatcher.hasStatus( + statusHasCode(Status.Code.UNAVAILABLE) + .andDescriptionContains("unexpected error: AssertionError: something bad happened")))); } @Test @@ -616,9 +741,13 @@ public void ldsResourceNotFound() { verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); // Server failed to return subscribed resource within expected time window. fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isFalse(); + assertThat(statusOrUpdate.getStatus().getCode()).isEqualTo(Status.Code.NOT_FOUND); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); + // Check metric data. 
verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); } @@ -626,16 +755,41 @@ public void ldsResourceNotFound() { public void ldsResourceUpdated_withXdstpResourceName_withUnknownAuthority() { String ldsResourceName = "xdstp://unknown.example.com/envoy.config.listener.v3.Listener/listener1"; - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),ldsResourceName, + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), ldsResourceName, + ldsResourceWatcher); + verify(ldsResourceWatcher).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.INVALID_ARGUMENT + && statusOr.getStatus().getDescription().equals( + "Wrong configuration: xds server does not exist for resource " + ldsResourceName))); + assertThat(resourceDiscoveryCalls.poll()).isNull(); + xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(), ldsResourceName, ldsResourceWatcher); - verify(ldsResourceWatcher).onError(errorCaptor.capture()); - Status error = errorCaptor.getValue(); - assertThat(error.getCode()).isEqualTo(Code.INVALID_ARGUMENT); - assertThat(error.getDescription()).isEqualTo( - "Wrong configuration: xds server does not exist for resource " + ldsResourceName); assertThat(resourceDiscoveryCalls.poll()).isNull(); - xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(),ldsResourceName, + } + + @Test + public void ldsResource_onError_cachedForNewWatcher() { + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + call.sendCompleted(); + @SuppressWarnings("unchecked") + ArgumentCaptor> errorCaptor = + ArgumentCaptor.forClass(StatusOr.class); + verify(ldsResourceWatcher, timeout(1000)).onResourceChanged(errorCaptor.capture()); + StatusOr initialError = errorCaptor.getValue(); + assertThat(initialError.hasValue()).isFalse(); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, + 
ldsResourceWatcher2); + @SuppressWarnings("unchecked") + ArgumentCaptor> secondErrorCaptor = + ArgumentCaptor.forClass(StatusOr.class); + verify(ldsResourceWatcher2, timeout(1000)).onResourceChanged(secondErrorCaptor.capture()); + StatusOr cachedError = secondErrorCaptor.getValue(); + + assertThat(cachedError).isEqualTo(initialError); assertThat(resourceDiscoveryCalls.poll()).isNull(); } @@ -674,21 +828,22 @@ public void ldsResponseErrorHandling_someResourcesFailedUnpack() { verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); // The response is NACKed with the same error message. call.verifyRequestNack(LDS, LDS_RESOURCE, "", "0000", NODE, errors); - verify(ldsResourceWatcher).onChanged(any(LdsUpdate.class)); + verify(ldsResourceWatcher).onResourceChanged(any()); } /** * Tests a subscribed LDS resource transitioned to and from the invalid state. * - * @see - * A40-csds-support.md + * @see + * A40-csds-support.md */ @Test public void ldsResponseErrorHandling_subscribedResourceInvalid() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"A", ldsResourceWatcher); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"B", ldsResourceWatcher); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"C", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "A", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "B", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "C", ldsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(LDS, "A"); @@ -706,6 +861,8 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { verifyResourceMetadataAcked(LDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); 
verifyResourceMetadataAcked(LDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + // Check metric data. + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), LDS.typeUrl()); call.verifyRequest(LDS, subscribedResourceNames, VERSION_1, "0000", NODE); // LDS -> {A, B}, version 2 @@ -720,7 +877,9 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { List errorsV2 = ImmutableList.of("LDS response Listener 'B' validation error: "); verifyResourceMetadataAcked(LDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + VERSION_2, TIME_INCREMENT * 2, errorsV2, true); + // Check metric data. + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), LDS.typeUrl()); if (!ignoreResourceDeletion()) { verifyResourceMetadataDoesNotExist(LDS, "C"); } else { @@ -736,6 +895,8 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { call.sendResponse(LDS, resourcesV3.values().asList(), VERSION_3, "0002"); // {A} -> does not exist // {B, C} -> ACK, version 3 + // Check metric data. 
+ verifyResourceValidInvalidCount(1, 2, 0, xdsServerInfo.target(), LDS.typeUrl()); if (!ignoreResourceDeletion()) { verifyResourceMetadataDoesNotExist(LDS, "A"); } else { @@ -748,15 +909,61 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { verifySubscribedResourcesMetadataSizes(3, 0, 0, 0); } + @Test + public void ldsResponseErrorHandling_subscribedResourceInvalid_withDataErrorHandlingEnabled() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "A", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "B", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "C", ldsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + assertThat(call).isNotNull(); + verifyResourceMetadataRequested(LDS, "A"); + verifyResourceMetadataRequested(LDS, "B"); + verifyResourceMetadataRequested(LDS, "C"); + ImmutableMap resourcesV1 = ImmutableMap.of( + "A", Any.pack(mf.buildListenerWithApiListenerForRds("A", "A.1")), + "B", Any.pack(mf.buildListenerWithApiListenerForRds("B", "B.1")), + "C", Any.pack(mf.buildListenerWithApiListenerForRds("C", "C.1"))); + call.sendResponse(LDS, resourcesV1.values().asList(), VERSION_1, "0000"); + verify(ldsResourceWatcher, times(3)).onResourceChanged(any()); + ImmutableMap resourcesV2 = ImmutableMap.of( + "A", Any.pack(mf.buildListenerWithApiListenerForRds("A", "A.2")), + "B", Any.pack(mf.buildListenerWithApiListenerInvalid("B"))); + call.sendResponse(LDS, resourcesV2.values().asList(), VERSION_2, "0001"); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(ldsResourceWatcher, times(2)).onAmbientError(statusCaptor.capture()); + List receivedStatuses = statusCaptor.getAllValues(); + assertThat(receivedStatuses).hasSize(2); + + assertThat( + receivedStatuses.stream().anyMatch( + status -> status.getCode() == Status.Code.UNAVAILABLE + && 
status.getDescription().contains("LDS response Listener 'B' validation error"))) + .isTrue(); + assertThat( + receivedStatuses.stream().anyMatch( + status -> status.getCode() == Status.Code.NOT_FOUND + && status.getDescription().contains("Resource C deleted from server"))) + .isTrue(); + List errorsV2 = ImmutableList.of("LDS response Listener 'B' validation error: "); + verifyResourceMetadataAcked(LDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); + verifyResourceMetadataNacked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, + VERSION_2, TIME_INCREMENT * 2, errorsV2, true); + verifyResourceMetadataAcked(LDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + @Test public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscription() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"A", ldsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"A.1", rdsResourceWatcher); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"B", ldsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"B.1", rdsResourceWatcher); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"C", ldsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"C.1", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "A", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "A.1", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "B", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "B.1", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "C", ldsResourceWatcher); + 
xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "C.1", rdsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(LDS, "A"); @@ -774,6 +981,7 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscripti "C", Any.pack(mf.buildListenerWithApiListenerForRds("C", "C.1"))); call.sendResponse(LDS, resourcesV1.values().asList(), VERSION_1, "0000"); // {A, B, C} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), LDS.typeUrl()); verifyResourceMetadataAcked(LDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(LDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); @@ -790,6 +998,8 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscripti verifyResourceMetadataAcked(RDS, "A.1", resourcesV11.get("A.1"), VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked(RDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked(RDS, "C.1", resourcesV11.get("C.1"), VERSION_1, TIME_INCREMENT * 2); + // Check metric data. + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), RDS.typeUrl()); // LDS -> {A, B}, version 2 // Failed to parse endpoint B @@ -800,11 +1010,13 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscripti // {A} -> ACK, version 2 // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> does not exist + // Check metric data. 
+ verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), LDS.typeUrl()); List errorsV2 = ImmutableList.of("LDS response Listener 'B' validation error: "); verifyResourceMetadataAcked(LDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 3); verifyResourceMetadataNacked( LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 3, - errorsV2); + errorsV2, true); if (!ignoreResourceDeletion()) { verifyResourceMetadataDoesNotExist(LDS, "C"); } else { @@ -830,8 +1042,10 @@ public void ldsResourceFound_containsVirtualHosts() { // Client sends an ACK LDS request. call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -845,8 +1059,10 @@ public void wrappedLdsResource() { // Client sends an ACK LDS request. 
call.sendResponse(LDS, mf.buildWrappedResource(testListenerVhosts), VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -858,14 +1074,16 @@ public void wrappedLdsResource_preferWrappedName() { ldsResourceWatcher); Any innerResource = Any.pack(mf.buildListenerWithApiListener("random_name" /* name */, - mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); + mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); // Client sends an ACK LDS request. 
call.sendResponse(LDS, mf.buildWrappedResourceWithName(innerResource, LDS_RESOURCE), VERSION_1, - "0000"); + "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, innerResource, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -879,8 +1097,10 @@ public void ldsResourceFound_containsRdsName() { // Client sends an ACK LDS request. call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerRds(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -897,9 +1117,11 @@ public void cachedLdsResource_data() { call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, watcher); - verify(watcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, 
watcher); + verify(watcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerRds(statusOrUpdate.getValue()); call.verifyNoMoreRequest(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -911,11 +1133,17 @@ public void cachedLdsResource_absent() { DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isFalse(); + assertThat(statusOrUpdate.getStatus().getCode()).isEqualTo(Status.Code.NOT_FOUND); // Add another watcher. ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, watcher); - verify(watcher).onResourceDoesNotExist(LDS_RESOURCE); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate1 = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate1.hasValue()).isFalse(); + assertThat(statusOrUpdate1.getStatus().getCode()).isEqualTo(Status.Code.NOT_FOUND); call.verifyNoMoreRequest(); verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -930,15 +1158,19 @@ public void ldsResourceUpdated() { // Initial LDS response. 
call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); // Updated LDS response. call.sendResponse(LDS, testListenerRds, VERSION_2, "0001"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); - verify(ldsResourceWatcher, times(2)).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerRds(statusOrUpdate.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_2, TIME_INCREMENT * 2); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); assertThat(channelForCustomAuthority).isNull(); @@ -954,8 +1186,10 @@ public void cancelResourceWatcherNotRemoveUrlSubscribers() { // Initial LDS response. 
call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); xdsClient.watchXdsResource(XdsListenerResource.getInstance(), @@ -968,8 +1202,10 @@ public void cancelResourceWatcherNotRemoveUrlSubscribers() { mf.buildRouteConfiguration("new", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); call.sendResponse(LDS, testListenerVhosts2, VERSION_2, "0001"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts2, VERSION_2, TIME_INCREMENT * 2); } @@ -987,8 +1223,10 @@ public void ldsResourceUpdated_withXdstpResourceName() { mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, ldsResourceName, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + 
verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyResourceMetadataAcked( LDS, ldsResourceName, testListenerVhosts, VERSION_1, TIME_INCREMENT); } @@ -1005,8 +1243,10 @@ public void ldsResourceUpdated_withXdstpResourceName_withEmptyAuthority() { mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, ldsResourceName, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyResourceMetadataAcked( LDS, ldsResourceName, testListenerVhosts, VERSION_1, TIME_INCREMENT); } @@ -1046,7 +1286,7 @@ public void ldsResourceUpdated_withXdstpResourceName_withWrongType() { call.verifyRequestNack( LDS, ldsResourceName, "", "0000", NODE, ImmutableList.of( - "Unsupported resource name: " + ldsResourceNameWithWrongType + " for type: LDS")); + "Unsupported resource name: " + ldsResourceNameWithWrongType + " for type: LDS")); } @Test @@ -1072,16 +1312,20 @@ public void rdsResourceUpdated_withXdstpResourceName_withWrongType() { public void rdsResourceUpdated_withXdstpResourceName_unknownAuthority() { String rdsResourceName = "xdstp://unknown.example.com/envoy.config.route.v3.RouteConfiguration/route1"; - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),rdsResourceName, + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), rdsResourceName, rdsResourceWatcher); - verify(rdsResourceWatcher).onError(errorCaptor.capture()); - Status error = errorCaptor.getValue(); + @SuppressWarnings("unchecked") + ArgumentCaptor> rdsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + 
verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr capturedUpdate = rdsUpdateCaptor.getValue(); + assertThat(capturedUpdate.hasValue()).isFalse(); + Status error = capturedUpdate.getStatus(); assertThat(error.getCode()).isEqualTo(Code.INVALID_ARGUMENT); assertThat(error.getDescription()).isEqualTo( "Wrong configuration: xds server does not exist for resource " + rdsResourceName); assertThat(resourceDiscoveryCalls.size()).isEqualTo(0); xdsClient.cancelXdsResourceWatch( - XdsRouteConfigureResource.getInstance(),rdsResourceName, rdsResourceWatcher); + XdsRouteConfigureResource.getInstance(), rdsResourceName, rdsResourceWatcher); assertThat(resourceDiscoveryCalls.size()).isEqualTo(0); } @@ -1107,15 +1351,19 @@ public void cdsResourceUpdated_withXdstpResourceName_withWrongType() { @Test public void cdsResourceUpdated_withXdstpResourceName_unknownAuthority() { String cdsResourceName = "xdstp://unknown.example.com/envoy.config.cluster.v3.Cluster/cluster1"; - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),cdsResourceName, + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), cdsResourceName, cdsResourceWatcher); - verify(cdsResourceWatcher).onError(errorCaptor.capture()); - Status error = errorCaptor.getValue(); + @SuppressWarnings("unchecked") + ArgumentCaptor> cdsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr capturedUpdate = cdsUpdateCaptor.getValue(); + assertThat(capturedUpdate.hasValue()).isFalse(); + Status error = capturedUpdate.getStatus(); assertThat(error.getCode()).isEqualTo(Code.INVALID_ARGUMENT); assertThat(error.getDescription()).isEqualTo( "Wrong configuration: xds server does not exist for resource " + cdsResourceName); assertThat(resourceDiscoveryCalls.poll()).isNull(); - xdsClient.cancelXdsResourceWatch(XdsClusterResource.getInstance(),cdsResourceName, + 
xdsClient.cancelXdsResourceWatch(XdsClusterResource.getInstance(), cdsResourceName, cdsResourceWatcher); assertThat(resourceDiscoveryCalls.poll()).isNull(); } @@ -1134,7 +1382,7 @@ public void edsResourceUpdated_withXdstpResourceName_withWrongType() { edsResourceNameWithWrongType, ImmutableList.of(mf.buildLocalityLbEndpoints( "region2", "zone2", "subzone2", - mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3), 2, 0)), + mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3, "endpoint-host-name"), 2, 0)), ImmutableList.of())); call.sendResponse(EDS, testEdsConfig, VERSION_1, "0000"); call.verifyRequestNack( @@ -1149,8 +1397,12 @@ public void edsResourceUpdated_withXdstpResourceName_unknownAuthority() { "xdstp://unknown.example.com/envoy.config.endpoint.v3.ClusterLoadAssignment/cluster1"; xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), edsResourceName, edsResourceWatcher); - verify(edsResourceWatcher).onError(errorCaptor.capture()); - Status error = errorCaptor.getValue(); + @SuppressWarnings("unchecked") + ArgumentCaptor> edsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(edsResourceWatcher).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr capturedUpdate = edsUpdateCaptor.getValue(); + assertThat(capturedUpdate.hasValue()).isFalse(); + Status error = capturedUpdate.getStatus(); assertThat(error.getCode()).isEqualTo(Code.INVALID_ARGUMENT); assertThat(error.getDescription()).isEqualTo( "Wrong configuration: xds server does not exist for resource " + edsResourceName); @@ -1199,11 +1451,13 @@ public void ldsResourceUpdate_withFaultInjection() { // Client sends an ACK LDS request. 
call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, listener, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); - LdsUpdate ldsUpdate = ldsUpdateCaptor.getValue(); + LdsUpdate ldsUpdate = statusOrUpdate.getValue(); assertThat(ldsUpdate.httpConnectionManager().virtualHosts()).hasSize(2); assertThat(ldsUpdate.httpConnectionManager().httpFilterConfigs().get(0).name) .isEqualTo("envoy.fault"); @@ -1228,6 +1482,7 @@ public void ldsResourceUpdate_withFaultInjection() { @Test public void ldsResourceDeleted() { Assume.assumeFalse(ignoreResourceDeletion()); + InOrder inOrder = inOrder(ldsResourceWatcher); DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); @@ -1236,15 +1491,20 @@ public void ldsResourceDeleted() { // Initial LDS response. call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); // Empty LDS response deletes the listener. 
call.sendResponse(LDS, Collections.emptyList(), VERSION_2, "0001"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); - verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate1 = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate1.hasValue()).isFalse(); + assertThat(statusOrUpdate1.getStatus().getCode()).isEqualTo(Status.Code.NOT_FOUND); verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); } @@ -1252,7 +1512,7 @@ public void ldsResourceDeleted() { /** * When ignore_resource_deletion server feature is on, xDS client should keep the deleted listener * on empty response, and resume the normal work when LDS contains the listener again. - * */ + */ @Test public void ldsResourceDeleted_ignoreResourceDeletion() { Assume.assumeTrue(ignoreResourceDeletion()); @@ -1264,8 +1524,8 @@ public void ldsResourceDeleted_ignoreResourceDeletion() { // Initial LDS response. call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue().getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -1273,32 +1533,204 @@ public void ldsResourceDeleted_ignoreResourceDeletion() { call.sendResponse(LDS, Collections.emptyList(), VERSION_2, "0001"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); // The resource is still ACKED at VERSION_1 (no changes). 
+ verify(ldsResourceWatcher).onAmbientError( + argThat(status -> status.getCode() == Status.Code.NOT_FOUND)); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); - // onResourceDoesNotExist not called - verify(ldsResourceWatcher, never()).onResourceDoesNotExist(LDS_RESOURCE); // Next update is correct, and contains the listener again. - call.sendResponse(LDS, testListenerVhosts, VERSION_3, "0003"); + Any updatedListener = Any.pack(mf.buildListenerWithApiListener(LDS_RESOURCE, + mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(VHOST_SIZE + 1)))); + call.sendResponse(LDS, updatedListener, VERSION_3, "0003"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_3, "0003", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + assertThat(ldsUpdateCaptor.getValue().getValue().httpConnectionManager().virtualHosts()) + .hasSize(VHOST_SIZE + 1); // LDS is now ACKEd at VERSION_3. - verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_3, - TIME_INCREMENT * 3); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, updatedListener, VERSION_3, TIME_INCREMENT * 3); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); verifyNoMoreInteractions(ldsResourceWatcher); } + /** + * When fail_on_data_errors server feature is on, xDS client should delete the cached listener + * and fail RPCs when LDS resource is deleted. 
+ */ + @Test + public void ldsResourceDeleted_failOnDataErrors_true() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, true); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of()) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + + InOrder inOrder = inOrder(ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + verifyResourceMetadataRequested(LDS, LDS_RESOURCE); + + // Initial LDS response. + call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + + // Empty LDS response deletes the listener and fails RPCs. 
+ call.sendResponse(LDS, Collections.emptyList(), VERSION_2, "0001"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate1 = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate1.hasValue()).isFalse(); + assertThat(statusOrUpdate1.getStatus().getCode()).isEqualTo(Status.Code.NOT_FOUND); + verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + /** + * When the fail_on_data_errors server feature is not present, the default behavior + * is to treat a resource deletion as an ambient error and preserve the cached resource. + */ + @Test + public void ldsResourceDeleted_failOnDataErrors_false() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, false); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of()) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + + InOrder inOrder = inOrder(ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + verifyResourceMetadataRequested(LDS, LDS_RESOURCE); + + // Initial LDS response. 
+ call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + + // Empty LDS response deletes the listener and fails RPCs. + call.sendResponse(LDS, Collections.emptyList(), VERSION_2, "0001"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + inOrder.verify(ldsResourceWatcher).onAmbientError(statusCaptor.capture()); + Status receivedStatus = statusCaptor.getValue(); + assertThat(receivedStatus.getCode()).isEqualTo(Status.Code.NOT_FOUND); + assertThat(receivedStatus.getDescription()).contains( + "Resource " + LDS_RESOURCE + " deleted from server"); + inOrder.verify(ldsResourceWatcher, never()).onResourceChanged(any()); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + /** + * Tests that fail_on_data_errors feature is ignored if the env var is not enabled, + * and the old behavior (dropping the resource) is used. 
+ */ + @Test + public void ldsResourceDeleted_failOnDataErrorsIgnoredWithoutEnvVar() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, true); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of()) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + + InOrder inOrder = inOrder(ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + assertThat(ldsUpdateCaptor.getValue().hasValue()).isTrue(); + call.sendResponse(LDS, Collections.emptyList(), VERSION_2, "0001"); + + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isFalse(); + assertThat(statusOrUpdate.getStatus().getCode()).isEqualTo(Status.Code.NOT_FOUND); + } + @Test @SuppressWarnings("unchecked") public void multipleLdsWatchers() { String ldsResourceTwo = "bar.googleapis.com"; ResourceWatcher watcher1 = mock(ResourceWatcher.class); ResourceWatcher watcher2 = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, 
ldsResourceWatcher); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),ldsResourceTwo, watcher1); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),ldsResourceTwo, watcher2); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), ldsResourceTwo, watcher1); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), ldsResourceTwo, watcher2); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(LDS, ImmutableList.of(LDS_RESOURCE, ldsResourceTwo), "", "", NODE); // Both LDS resources were requested. @@ -1307,9 +1739,12 @@ public void multipleLdsWatchers() { verifySubscribedResourcesMetadataSizes(2, 0, 0, 0); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); - verify(watcher1).onResourceDoesNotExist(ldsResourceTwo); - verify(watcher2).onResourceDoesNotExist(ldsResourceTwo); + verify(ldsResourceWatcher).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() && statusOr.getStatus().getDescription().contains(LDS_RESOURCE))); + verify(watcher1).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() && statusOr.getStatus().getDescription().contains(ldsResourceTwo))); + verify(watcher2).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() && statusOr.getStatus().getDescription().contains(ldsResourceTwo))); verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); verifyResourceMetadataDoesNotExist(LDS, ldsResourceTwo); verifySubscribedResourcesMetadataSizes(2, 0, 0, 0); @@ -1317,16 +1752,22 @@ public void multipleLdsWatchers() { Any listenerTwo = Any.pack(mf.buildListenerWithApiListenerForRds(ldsResourceTwo, RDS_RESOURCE)); call.sendResponse(LDS, ImmutableList.of(testListenerVhosts, listenerTwo), VERSION_1, "0000"); // ResourceWatcher called with listenerVhosts. 
- verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); // watcher1 called with listenerTwo. - verify(watcher1).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()).isNull(); + verify(watcher1, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerRds(statusOrUpdate.getValue()); + assertThat(statusOrUpdate.getValue().httpConnectionManager().virtualHosts()).isNull(); // watcher2 called with listenerTwo. - verify(watcher2).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()).isNull(); + verify(watcher2, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerRds(statusOrUpdate.getValue()); + assertThat(statusOrUpdate.getValue().httpConnectionManager().virtualHosts()).isNull(); // Metadata of both listeners is stored. 
verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(LDS, ldsResourceTwo, listenerTwo, VERSION_1, TIME_INCREMENT); @@ -1338,7 +1779,7 @@ public void rdsResourceNotFound() { DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); Any routeConfig = Any.pack(mf.buildRouteConfiguration("route-bar.googleapis.com", - mf.buildOpaqueVirtualHosts(2))); + mf.buildOpaqueVirtualHosts(2))); call.sendResponse(RDS, routeConfig, VERSION_1, "0000"); // Client sends an ACK RDS request. @@ -1348,7 +1789,8 @@ public void rdsResourceNotFound() { verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); // Server failed to return subscribed resource within expected time window. fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); + verify(rdsResourceWatcher).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(RDS_RESOURCE))); assertThat(fakeClock.getPendingTasks(RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(RDS, RDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); @@ -1389,7 +1831,7 @@ public void rdsResponseErrorHandling_someResourcesFailedUnpack() { verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); // The response is NACKed with the same error message. 
call.verifyRequestNack(RDS, RDS_RESOURCE, "", "0000", NODE, errors); - verify(rdsResourceWatcher).onChanged(any(RdsUpdate.class)); + verify(rdsResourceWatcher).onResourceChanged(any()); } @Test @@ -1397,6 +1839,7 @@ public void rdsResponseErrorHandling_nackWeightedSumZero() { DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); + String expectedErrorDetail = "Sum of cluster weights should be above 0"; io.envoyproxy.envoy.config.route.v3.RouteAction routeAction = io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() @@ -1430,25 +1873,29 @@ public void rdsResponseErrorHandling_nackWeightedSumZero() { "RDS response RouteConfiguration \'route-configuration.googleapis.com\' validation error: " + "RouteConfiguration contains invalid virtual host: Virtual host [do not care] " + "contains invalid route : Route [route-blade] contains invalid RouteAction: " - + "Sum of cluster weights should be above 0."); + + expectedErrorDetail); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); // The response is NACKed with the same error message. call.verifyRequestNack(RDS, RDS_RESOURCE, "", "0000", NODE, errors); - verify(rdsResourceWatcher, never()).onChanged(any(RdsUpdate.class)); + verify(rdsResourceWatcher).onResourceChanged(argThat( + statusOr -> !statusOr.hasValue() && statusOr.getStatus().getDescription() + .contains(expectedErrorDetail))); + verify(rdsResourceWatcher, never()).onResourceChanged(argThat(StatusOr::hasValue)); } /** * Tests a subscribed RDS resource transitioned to and from the invalid state. 
* - * @see - * A40-csds-support.md + * @see + * A40-csds-support.md */ @Test public void rdsResponseErrorHandling_subscribedResourceInvalid() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"A", rdsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"B", rdsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"C", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "A", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "B", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "C", rdsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(RDS, "A"); @@ -1467,6 +1914,8 @@ public void rdsResponseErrorHandling_subscribedResourceInvalid() { verifyResourceMetadataAcked(RDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(RDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(RDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + // Check metric data. 
+ verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), RDS.typeUrl()); call.verifyRequest(RDS, subscribedResourceNames, VERSION_1, "0000", NODE); // RDS -> {A, B}, version 2 @@ -1478,11 +1927,13 @@ public void rdsResponseErrorHandling_subscribedResourceInvalid() { // {A} -> ACK, version 2 // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), + RDS.typeUrl()); List errorsV2 = ImmutableList.of("RDS response RouteConfiguration 'B' validation error: "); verifyResourceMetadataAcked(RDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(RDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + VERSION_2, TIME_INCREMENT * 2, errorsV2, true); verifyResourceMetadataAcked(RDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); call.verifyRequestNack(RDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); @@ -1494,6 +1945,8 @@ public void rdsResponseErrorHandling_subscribedResourceInvalid() { call.sendResponse(RDS, resourcesV3.values().asList(), VERSION_3, "0002"); // {A} -> ACK, version 2 // {B, C} -> ACK, version 3 + verifyResourceValidInvalidCount(1, 2, 0, xdsServerInfo.target(), + RDS.typeUrl()); verifyResourceMetadataAcked(RDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataAcked(RDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); verifyResourceMetadataAcked(RDS, "C", resourcesV3.get("C"), VERSION_3, TIME_INCREMENT * 3); @@ -1509,8 +1962,10 @@ public void rdsResourceFound() { // Client sends an ACK RDS request. 
call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); - verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenRouteConfig(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); @@ -1524,8 +1979,10 @@ public void wrappedRdsResource() { // Client sends an ACK RDS request. call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); - verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenRouteConfig(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); @@ -1542,9 +1999,11 @@ public void cachedRdsResource_data() { call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, watcher); - verify(watcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = 
rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenRouteConfig(statusOrUpdate.getValue()); call.verifyNoMoreRequest(); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); @@ -1556,11 +2015,15 @@ public void cachedRdsResource_absent() { DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); + verify(rdsResourceWatcher).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() && statusOr.getStatus().getDescription().contains(RDS_RESOURCE) + && statusOr.getStatus().getDescription().contains(RDS_RESOURCE))); // Add another watcher. ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, watcher); - verify(watcher).onResourceDoesNotExist(RDS_RESOURCE); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() && statusOr.getStatus().getDescription().contains(RDS_RESOURCE) + && statusOr.getStatus().getDescription().contains(RDS_RESOURCE))); call.verifyNoMoreRequest(); verifyResourceMetadataDoesNotExist(RDS, RDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); @@ -1575,8 +2038,10 @@ public void rdsResourceUpdated() { // Initial RDS response. 
call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); - verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenRouteConfig(statusOrUpdate.getValue()); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); // Updated RDS response. @@ -1586,18 +2051,49 @@ public void rdsResourceUpdated() { // Client sends an ACK RDS request. call.verifyRequest(RDS, RDS_RESOURCE, VERSION_2, "0001", NODE); - verify(rdsResourceWatcher, times(2)).onChanged(rdsUpdateCaptor.capture()); - assertThat(rdsUpdateCaptor.getValue().virtualHosts).hasSize(4); + verify(rdsResourceWatcher, times(2)).onResourceChanged(rdsUpdateCaptor.capture()); + statusOrUpdate = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + assertThat(statusOrUpdate.getValue().virtualHosts).hasSize(4); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, routeConfigUpdated, VERSION_2, TIME_INCREMENT * 2); - verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); + } + + @Test + public void rdsResourceInvalid() { + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "A", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "B", rdsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + assertThat(call).isNotNull(); + verifyResourceMetadataRequested(RDS, "A"); + verifyResourceMetadataRequested(RDS, "B"); + verifySubscribedResourcesMetadataSizes(0, 0, 2, 0); + + // RDS -> {A, B}, version 1 + // Failed to parse endpoint B + List vhostsV1 = mf.buildOpaqueVirtualHosts(1); + ImmutableMap resourcesV1 = ImmutableMap.of( + "A", Any.pack(mf.buildRouteConfiguration("A", vhostsV1)), + 
"B", Any.pack(mf.buildRouteConfigurationInvalid("B"))); + call.sendResponse(RDS, resourcesV1.values().asList(), VERSION_1, "0000"); + + // {A} -> ACK, version 1 + // {B} -> NACK, version 1, rejected version 1, rejected reason: Failed to parse B + List errorsV1 = + ImmutableList.of("RDS response RouteConfiguration 'B' validation error: "); + verifyResourceMetadataAcked(RDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataNacked(RDS, "B", null, "", 0, + VERSION_1, TIME_INCREMENT, errorsV1, false); + // Check metric data. + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), RDS.typeUrl()); + verifySubscribedResourcesMetadataSizes(0, 0, 2, 0); } @Test public void rdsResourceDeletedByLdsApiListener() { - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); verifyResourceMetadataRequested(LDS, LDS_RESOURCE); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); @@ -1605,15 +2101,19 @@ public void rdsResourceDeletedByLdsApiListener() { DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerRds(statusOrUpdate.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); verifySubscribedResourcesMetadataSizes(1, 0, 1, 0); call.sendResponse(RDS, 
testRouteConfig, VERSION_1, "0000"); - verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr statusOrUpdate1 = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenRouteConfig(statusOrUpdate1.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT * 2); verifySubscribedResourcesMetadataSizes(1, 0, 1, 0); @@ -1623,8 +2123,10 @@ public void rdsResourceDeletedByLdsApiListener() { // Note that this must work the same despite the ignore_resource_deletion feature is on. // This happens because the Listener is getting replaced, and not deleted. call.sendResponse(LDS, testListenerVhosts, VERSION_2, "0001"); - verify(ldsResourceWatcher, times(2)).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyNoMoreInteractions(rdsResourceWatcher); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked( @@ -1656,11 +2158,13 @@ public void rdsResourcesDeletedByLdsTcpListener() { // referencing RDS_RESOURCE. 
DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.sendResponse(LDS, packedListener, VERSION_1, "0000"); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); - assertThat(ldsUpdateCaptor.getValue().listener().filterChains()).hasSize(1); + assertThat(statusOrUpdate.getValue().listener().filterChains()).hasSize(1); FilterChain parsedFilterChain = Iterables.getOnlyElement( - ldsUpdateCaptor.getValue().listener().filterChains()); + statusOrUpdate.getValue().listener().filterChains()); assertThat(parsedFilterChain.httpConnectionManager().rdsName()).isEqualTo(RDS_RESOURCE); verifyResourceMetadataAcked(LDS, LISTENER_RESOURCE, packedListener, VERSION_1, TIME_INCREMENT); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); @@ -1668,8 +2172,10 @@ public void rdsResourcesDeletedByLdsTcpListener() { // Simulates receiving the requested RDS resource. 
call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); - verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr statusOrUpdate1 = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenRouteConfig(statusOrUpdate1.getValue()); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT * 2); // Simulates receiving an updated version of the requested LDS resource as a TCP listener @@ -1687,12 +2193,15 @@ public void rdsResourcesDeletedByLdsTcpListener() { packedListener = Any.pack(mf.buildListenerWithFilterChain(LISTENER_RESOURCE, 7000, "0.0.0.0", filterChain)); call.sendResponse(LDS, packedListener, VERSION_2, "0001"); - verify(ldsResourceWatcher, times(2)).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().listener().filterChains()).hasSize(1); + verify(ldsResourceWatcher, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + assertThat(statusOrUpdate.getValue().listener().filterChains()).hasSize(1); parsedFilterChain = Iterables.getOnlyElement( - ldsUpdateCaptor.getValue().listener().filterChains()); + statusOrUpdate.getValue().listener().filterChains()); assertThat(parsedFilterChain.httpConnectionManager().virtualHosts()).hasSize(VHOST_SIZE); - verify(rdsResourceWatcher, never()).onResourceDoesNotExist(RDS_RESOURCE); + verify(rdsResourceWatcher, never()).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() && statusOr.getStatus().getDescription().equals(RDS_RESOURCE))); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked( LDS, LISTENER_RESOURCE, packedListener, VERSION_2, TIME_INCREMENT * 3); @@ -1705,10 +2214,10 @@ public void 
multipleRdsWatchers() { String rdsResourceTwo = "route-bar.googleapis.com"; ResourceWatcher watcher1 = mock(ResourceWatcher.class); ResourceWatcher watcher2 = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),rdsResourceTwo, watcher1); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),rdsResourceTwo, watcher2); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), rdsResourceTwo, watcher1); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), rdsResourceTwo, watcher2); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(RDS, Arrays.asList(RDS_RESOURCE, rdsResourceTwo), "", "", NODE); // Both RDS resources were requested. @@ -1717,16 +2226,25 @@ public void multipleRdsWatchers() { verifySubscribedResourcesMetadataSizes(0, 0, 2, 0); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); - verify(watcher1).onResourceDoesNotExist(rdsResourceTwo); - verify(watcher2).onResourceDoesNotExist(rdsResourceTwo); + verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.NOT_FOUND)); + verify(watcher1).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.NOT_FOUND)); + verify(watcher2).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.NOT_FOUND)); verifyResourceMetadataDoesNotExist(RDS, RDS_RESOURCE); verifyResourceMetadataDoesNotExist(RDS, rdsResourceTwo); verifySubscribedResourcesMetadataSizes(0, 0, 2, 0); call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); - 
verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + ArgumentCaptor> rdsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(rdsResourceWatcher, times(2)).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr capturedUpdate1 = rdsUpdateCaptor.getAllValues().get(1); + assertThat(capturedUpdate1.hasValue()).isTrue(); + verifyGoldenRouteConfig(capturedUpdate1.getValue()); verifyNoMoreInteractions(watcher1, watcher2); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifyResourceMetadataDoesNotExist(RDS, rdsResourceTwo); @@ -1735,13 +2253,22 @@ public void multipleRdsWatchers() { Any routeConfigTwo = Any.pack(mf.buildRouteConfiguration(rdsResourceTwo, mf.buildOpaqueVirtualHosts(4))); call.sendResponse(RDS, routeConfigTwo, VERSION_2, "0002"); - verify(watcher1).onChanged(rdsUpdateCaptor.capture()); - assertThat(rdsUpdateCaptor.getValue().virtualHosts).hasSize(4); - verify(watcher2).onChanged(rdsUpdateCaptor.capture()); - assertThat(rdsUpdateCaptor.getValue().virtualHosts).hasSize(4); + ArgumentCaptor> watcher1Captor = + ArgumentCaptor.forClass(StatusOr.class); + verify(watcher1, times(2)).onResourceChanged(watcher1Captor.capture()); + StatusOr capturedUpdate2 = watcher1Captor.getAllValues().get(1); + assertThat(capturedUpdate2.hasValue()).isTrue(); + assertThat(capturedUpdate2.getValue().virtualHosts).hasSize(4); + ArgumentCaptor> watcher2Captor = + ArgumentCaptor.forClass(StatusOr.class); + verify(watcher2, times(2)).onResourceChanged(watcher2Captor.capture()); + StatusOr capturedUpdate3 = watcher2Captor.getAllValues().get(1); + assertThat(capturedUpdate3.hasValue()).isTrue(); + assertThat(capturedUpdate3.getValue().virtualHosts).hasSize(4); verifyNoMoreInteractions(rdsResourceWatcher); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); - verifyResourceMetadataAcked(RDS, rdsResourceTwo, routeConfigTwo, 
VERSION_2, TIME_INCREMENT * 2); + verifyResourceMetadataAcked(RDS, rdsResourceTwo, routeConfigTwo, VERSION_2, + TIME_INCREMENT * 2); verifySubscribedResourcesMetadataSizes(0, 0, 2, 0); } @@ -1764,7 +2291,8 @@ public void cdsResourceNotFound() { verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); // Server failed to return subscribed resource within expected time window. fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); + verify(cdsResourceWatcher).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(CDS_RESOURCE))); assertThat(fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); @@ -1806,21 +2334,22 @@ public void cdsResponseErrorHandling_someResourcesFailedUnpack() { verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); // The response is NACKed with the same error message. call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, errors); - verify(cdsResourceWatcher).onChanged(any(CdsUpdate.class)); + verify(cdsResourceWatcher).onResourceChanged(any()); } /** * Tests a subscribed CDS resource transitioned to and from the invalid state. 
* - * @see - * A40-csds-support.md + * @see + * A40-csds-support.md */ @Test public void cdsResponseErrorHandling_subscribedResourceInvalid() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"A", cdsResourceWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"B", cdsResourceWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"C", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "A", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "B", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "C", cdsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(CDS, "A"); @@ -1841,6 +2370,8 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { ))); call.sendResponse(CDS, resourcesV1.values().asList(), VERSION_1, "0000"); // {A, B, C} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), + CDS.typeUrl()); verifyResourceMetadataAcked(CDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(CDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); @@ -1857,10 +2388,12 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { // {A} -> ACK, version 2 // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> does not exist + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), + CDS.typeUrl()); List errorsV2 = ImmutableList.of("CDS response Cluster 'B' validation error: "); verifyResourceMetadataAcked(CDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, - VERSION_2, 
TIME_INCREMENT * 2, errorsV2); + VERSION_2, TIME_INCREMENT * 2, errorsV2, true); if (!ignoreResourceDeletion()) { verifyResourceMetadataDoesNotExist(CDS, "C"); } else { @@ -1880,6 +2413,8 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { call.sendResponse(CDS, resourcesV3.values().asList(), VERSION_3, "0002"); // {A} -> does not exit // {B, C} -> ACK, version 3 + verifyResourceValidInvalidCount(1, 2, 0, xdsServerInfo.target(), + CDS.typeUrl()); if (!ignoreResourceDeletion()) { verifyResourceMetadataDoesNotExist(CDS, "A"); } else { @@ -1888,18 +2423,19 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { } verifyResourceMetadataAcked(CDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); verifyResourceMetadataAcked(CDS, "C", resourcesV3.get("C"), VERSION_3, TIME_INCREMENT * 3); + call.verifyRequest(CDS, subscribedResourceNames, VERSION_3, "0002", NODE); } @Test public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscription() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"A", cdsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"A.1", edsResourceWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"B", cdsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"B.1", edsResourceWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"C", cdsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"C.1", edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "A", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "A.1", edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "B", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "B.1", edsResourceWatcher); + 
xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "C", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "C.1", edsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(CDS, "A"); @@ -1923,6 +2459,8 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscripti ))); call.sendResponse(CDS, resourcesV1.values().asList(), VERSION_1, "0000"); // {A, B, C} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), + CDS.typeUrl()); verifyResourceMetadataAcked(CDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(CDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); @@ -1937,6 +2475,8 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscripti "C.1", Any.pack(mf.buildClusterLoadAssignment("C.1", endpointsV1, dropOverloads))); call.sendResponse(EDS, resourcesV11.values().asList(), VERSION_1, "0000"); // {A.1, B.1, C.1} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), + EDS.typeUrl()); verifyResourceMetadataAcked(EDS, "A.1", resourcesV11.get("A.1"), VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked(EDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked(EDS, "C.1", resourcesV11.get("C.1"), VERSION_1, TIME_INCREMENT * 2); @@ -1952,11 +2492,13 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscripti // {A} -> ACK, version 2 // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> does not exist + // Check metric data. 
+ verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), CDS.typeUrl()); List errorsV2 = ImmutableList.of("CDS response Cluster 'B' validation error: "); verifyResourceMetadataAcked(CDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 3); verifyResourceMetadataNacked( CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 3, - errorsV2); + errorsV2, true); if (!ignoreResourceDeletion()) { verifyResourceMetadataDoesNotExist(CDS, "C"); } else { @@ -1982,8 +2524,10 @@ public void cdsResourceFound() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); @@ -1998,8 +2542,10 @@ public void wrappedCdsResource() { // Client sent an ACK CDS request. 
call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); @@ -2019,8 +2565,10 @@ public void cdsResourceFound_leastRequestLbPolicy() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isNull(); @@ -2051,8 +2599,10 @@ public void cdsResourceFound_ringHashLbPolicy() { // Client sent an ACK CDS request. 
call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isNull(); @@ -2082,8 +2632,10 @@ public void cdsResponseWithAggregateCluster() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.AGGREGATE); LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); @@ -2096,6 +2648,23 @@ public void cdsResponseWithAggregateCluster() { verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); } + @Test + public void cdsResponseWithEmptyAggregateCluster() { + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + List candidates = Arrays.asList(); + Any clusterAggregate = + Any.pack(mf.buildAggregateCluster(CDS_RESOURCE, "round_robin", null, null, candidates)); + call.sendResponse(CDS, clusterAggregate, VERSION_1, "0000"); + + // Client sent an ACK CDS request. 
+ String errorMsg = "CDS response Cluster 'cluster.googleapis.com' validation error: " + + "Cluster cluster.googleapis.com: aggregate ClusterConfig.clusters must not be empty"; + call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + verifyStatusWithNodeId(cdsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, errorMsg); + } + @Test public void cdsResponseWithCircuitBreakers() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, @@ -2107,8 +2676,10 @@ public void cdsResponseWithCircuitBreakers() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isNull(); @@ -2129,7 +2700,6 @@ public void cdsResponseWithCircuitBreakers() { * CDS response containing UpstreamTlsContext for a cluster. 
*/ @Test - @SuppressWarnings("deprecation") public void cdsResponseWithUpstreamTlsContext() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); @@ -2142,7 +2712,7 @@ public void cdsResponseWithUpstreamTlsContext() { "envoy.transport_sockets.tls", null, null)); List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", - "dns-service-bar.googleapis.com", 443, "round_robin", null, null,false, null, null)), + "dns-service-bar.googleapis.com", 443, "round_robin", null, null, false, null, null)), clusterEds, Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null, null))); @@ -2151,11 +2721,13 @@ public void cdsResponseWithUpstreamTlsContext() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); verify(cdsResourceWatcher, times(1)) - .onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); - CommonTlsContext.CertificateProviderInstance certificateProviderInstance = + .onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); + CertificateProviderPluginInstance certificateProviderInstance = cdsUpdate.upstreamTlsContext().getCommonTlsContext().getCombinedValidationContext() - .getValidationContextCertificateProviderInstance(); + .getDefaultValidationContext().getCaCertificateProviderInstance(); assertThat(certificateProviderInstance.getInstanceName()).isEqualTo("cert-instance-name"); assertThat(certificateProviderInstance.getCertificateName()).isEqualTo("cert1"); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusterEds, VERSION_1, TIME_INCREMENT); @@ -2166,7 +2738,6 @@ public void cdsResponseWithUpstreamTlsContext() { * CDS response containing new 
UpstreamTlsContext for a cluster. */ @Test - @SuppressWarnings("deprecation") public void cdsResponseWithNewUpstreamTlsContext() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); @@ -2174,7 +2745,7 @@ public void cdsResponseWithNewUpstreamTlsContext() { // Management server sends back CDS response with UpstreamTlsContext. Any clusterEds = Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", - null, null,true, + null, null, true, mf.buildNewUpstreamTlsContext("cert-instance-name", "cert1"), "envoy.transport_sockets.tls", null, null)); List clusters = ImmutableList.of( @@ -2187,8 +2758,10 @@ public void cdsResponseWithNewUpstreamTlsContext() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher, times(1)).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher, times(1)).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); CertificateProviderPluginInstance certificateProviderInstance = cdsUpdate.upstreamTlsContext().getCommonTlsContext().getValidationContext() .getCaCertificateProviderInstance(); @@ -2215,19 +2788,19 @@ public void cdsResponseErrorHandling_badUpstreamTlsContext() { // The response NACKed with errors indicating indices of the failed resources. 
String errorMsg = "CDS response Cluster 'cluster.googleapis.com' validation error: " - + "Cluster cluster.googleapis.com: malformed UpstreamTlsContext: " - + "io.grpc.xds.client.XdsResourceType$ResourceInvalidException: " - + "ca_certificate_provider_instance is required in upstream-tls-context"; + + "Cluster cluster.googleapis.com: malformed UpstreamTlsContext: " + + "io.grpc.xds.client.XdsResourceType$ResourceInvalidException: " + + "ca_certificate_provider_instance or system_root_certs is required in " + + "upstream-tls-context"; call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); - verify(cdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + verifyStatusWithNodeId(cdsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, errorMsg); } /** * CDS response containing OutlierDetection for a cluster. */ @Test - @SuppressWarnings("deprecation") public void cdsResponseWithOutlierDetection() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); @@ -2254,7 +2827,7 @@ public void cdsResponseWithOutlierDetection() { "envoy.transport_sockets.tls", null, outlierDetectionXds)); List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", - "dns-service-bar.googleapis.com", 443, "round_robin", null, null,false, null, null)), + "dns-service-bar.googleapis.com", 443, "round_robin", null, null, false, null, null)), clusterEds, Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null, outlierDetectionXds))); @@ -2262,8 +2835,10 @@ public void cdsResponseWithOutlierDetection() { // Client sent an ACK CDS request. 
call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher, times(1)).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher, times(1)).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); // The outlier detection config in CdsUpdate should match what we get from xDS. EnvoyServerProtoData.OutlierDetection outlierDetection = cdsUpdate.outlierDetection(); @@ -2296,7 +2871,6 @@ public void cdsResponseWithOutlierDetection() { * CDS response containing OutlierDetection for a cluster. */ @Test - @SuppressWarnings("deprecation") public void cdsResponseWithInvalidOutlierDetectionNacks() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, @@ -2313,7 +2887,7 @@ public void cdsResponseWithInvalidOutlierDetectionNacks() { "envoy.transport_sockets.tls", null, outlierDetectionXds)); List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", - "dns-service-bar.googleapis.com", 443, "round_robin", null, null,false, null, null)), + "dns-service-bar.googleapis.com", 443, "round_robin", null, null, false, null, null)), clusterEds, Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null, outlierDetectionXds))); @@ -2324,8 +2898,8 @@ public void cdsResponseWithInvalidOutlierDetectionNacks() { + "io.grpc.xds.client.XdsResourceType$ResourceInvalidException: outlier_detection " + "max_ejection_percent is > 100"; call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); - verify(cdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + 
verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + verifyStatusWithNodeId(cdsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, errorMsg); } @Test(expected = ResourceInvalidException.class) @@ -2419,8 +2993,8 @@ public void cdsResponseErrorHandling_badTransportSocketName() { String errorMsg = "CDS response Cluster 'cluster.googleapis.com' validation error: " + "transport-socket with name envoy.transport_sockets.bad not supported."; call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); - verify(cdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + verifyStatusWithNodeId(cdsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, errorMsg); } @Test @@ -2433,8 +3007,7 @@ public void cdsResponseErrorHandling_xdstpWithoutEdsConfig() { )); final Any okClusterRoundRobin = Any.pack(mf.buildEdsCluster(cdsResourceName, "eds-service-bar.googleapis.com", - "round_robin", null,null, false, null, "envoy.transport_sockets.tls", null, null)); - + "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null, null)); DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), cdsResourceName, cdsResourceWatcher); @@ -2463,8 +3036,10 @@ public void cachedCdsResource_data() { ResourceWatcher watcher = mock(ResourceWatcher.class); xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, watcher); - verify(watcher).onChanged(cdsUpdateCaptor.capture()); - verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verify(watcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); call.verifyNoMoreRequest(); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, 
testClusterRoundRobin, VERSION_1, TIME_INCREMENT); @@ -2478,10 +3053,12 @@ public void cachedCdsResource_absent() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); + verify(cdsResourceWatcher).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(CDS_RESOURCE))); ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, watcher); - verify(watcher).onResourceDoesNotExist(CDS_RESOURCE); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(CDS_RESOURCE))); call.verifyNoMoreRequest(); verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); @@ -2501,8 +3078,10 @@ public void cdsResourceUpdated() { null, null, false, null, null)); call.sendResponse(CDS, clusterDns, VERSION_1, "0000"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.LOGICAL_DNS); assertThat(cdsUpdate.dnsHostName()).isEqualTo(dnsHostAddr + ":" + dnsHostPort); @@ -2524,8 +3103,10 @@ public void cdsResourceUpdated() { )); call.sendResponse(CDS, clusterEds, VERSION_2, "0001"); call.verifyRequest(CDS, 
CDS_RESOURCE, VERSION_2, "0001", NODE); - verify(cdsResourceWatcher, times(2)).onChanged(cdsUpdateCaptor.capture()); - cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher, times(2)).onResourceChanged(cdsUpdateCaptor.capture()); + statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + cdsUpdate = statusOrUpdate.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isEqualTo(edsService); @@ -2566,27 +3147,27 @@ public void cdsResourceUpdatedWithDuplicate() { // Configure with round robin, the update should be sent to the watcher. call.sendResponse(CDS, roundRobinConfig, VERSION_2, "0001"); - verify(cdsResourceWatcher, times(1)).onChanged(isA(CdsUpdate.class)); + verify(cdsResourceWatcher, times(1)).onResourceChanged(argThat(StatusOr::hasValue)); // Second update is identical, watcher should not get an additional update. call.sendResponse(CDS, roundRobinConfig, VERSION_2, "0002"); - verify(cdsResourceWatcher, times(1)).onChanged(isA(CdsUpdate.class)); + verify(cdsResourceWatcher, times(1)).onResourceChanged(any()); // Now we switch to ring hash so the watcher should be notified. call.sendResponse(CDS, ringHashConfig, VERSION_2, "0003"); - verify(cdsResourceWatcher, times(2)).onChanged(isA(CdsUpdate.class)); + verify(cdsResourceWatcher, times(2)).onResourceChanged(argThat(StatusOr::hasValue)); // Second update to ring hash should not result in watcher being notified. call.sendResponse(CDS, ringHashConfig, VERSION_2, "0004"); - verify(cdsResourceWatcher, times(2)).onChanged(isA(CdsUpdate.class)); + verify(cdsResourceWatcher, times(2)).onResourceChanged(any()); // Now we switch to least request so the watcher should be notified. 
call.sendResponse(CDS, leastRequestConfig, VERSION_2, "0005"); - verify(cdsResourceWatcher, times(3)).onChanged(isA(CdsUpdate.class)); + verify(cdsResourceWatcher, times(3)).onResourceChanged(argThat(StatusOr::hasValue)); // Second update to least request should not result in watcher being notified. call.sendResponse(CDS, leastRequestConfig, VERSION_2, "0006"); - verify(cdsResourceWatcher, times(3)).onChanged(isA(CdsUpdate.class)); + verify(cdsResourceWatcher, times(3)).onResourceChanged(any()); } @Test @@ -2600,8 +3181,10 @@ public void cdsResourceDeleted() { // Initial CDS response. call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); @@ -2609,7 +3192,8 @@ public void cdsResourceDeleted() { // Empty CDS response deletes the cluster. 
call.sendResponse(CDS, Collections.emptyList(), VERSION_2, "0001"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_2, "0001", NODE); - verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); + verify(cdsResourceWatcher).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(CDS_RESOURCE))); verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); } @@ -2617,7 +3201,7 @@ public void cdsResourceDeleted() { /** * When ignore_resource_deletion server feature is on, xDS client should keep the deleted cluster * on empty response, and resume the normal work when CDS contains the cluster again. - * */ + */ @Test public void cdsResourceDeleted_ignoreResourceDeletion() { Assume.assumeTrue(ignoreResourceDeletion()); @@ -2629,8 +3213,10 @@ public void cdsResourceDeleted_ignoreResourceDeletion() { // Initial CDS response. call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); @@ -2644,28 +3230,253 @@ public void cdsResourceDeleted_ignoreResourceDeletion() { TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); // onResourceDoesNotExist must not be called. 
- verify(ldsResourceWatcher, never()).onResourceDoesNotExist(CDS_RESOURCE); + verify(ldsResourceWatcher, never()).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(CDS_RESOURCE))); // Next update is correct, and contains the cluster again. call.sendResponse(CDS, testClusterRoundRobin, VERSION_3, "0003"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_3, "0003", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_3, TIME_INCREMENT * 3); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); verifyNoMoreInteractions(ldsResourceWatcher); } + /** + * When fail_on_data_errors server feature is on, xDS client should delete the cached cluster + * and fail RPCs when CDS resource is deleted. 
+ */ + @Test + public void cdsResourceDeleted_failOnDataErrors_true() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, true); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of()) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + verifyResourceMetadataRequested(CDS, CDS_RESOURCE); + + // Initial CDS response. + call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, + TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + + // Empty CDS response deletes the cluster and fails RPCs. 
+ call.sendResponse(CDS, Collections.emptyList(), VERSION_2, "0001"); + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_2, "0001", NODE); + verify(cdsResourceWatcher).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(CDS_RESOURCE))); + verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + /** + * When fail_on_data_errors server feature is on, xDS client should delete the cached cluster + * and fail RPCs when CDS resource is deleted. + */ + @Test + public void cdsResourceDeleted_failOnDataErrors_false() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + // Set failOnDataErrors to false for this test case. + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, false); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of()) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + + InOrder inOrder = inOrder(cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + verifyResourceMetadataRequested(CDS, CDS_RESOURCE); + + // Initial CDS response. 
+ call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + inOrder.verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, + TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + + // Empty CDS response should trigger an ambient error. + call.sendResponse(CDS, Collections.emptyList(), VERSION_2, "0001"); + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_2, "0001", NODE); + + // Verify that onAmbientError() is called. + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + inOrder.verify(cdsResourceWatcher).onAmbientError(statusCaptor.capture()); + Status receivedStatus = statusCaptor.getValue(); + assertThat(receivedStatus.getCode()).isEqualTo(Status.Code.NOT_FOUND); + assertThat(receivedStatus.getDescription()).contains( + "Resource " + CDS_RESOURCE + " deleted from server"); + + // Verify that onResourceChanged() is NOT called again. + inOrder.verify(cdsResourceWatcher, never()).onResourceChanged(any()); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + /** + * Tests that a NACKed LDS resource update drops the cached resource when fail_on_data_errors + * is enabled. 
+ */ + @Test + public void ldsResourceNacked_withFailOnDataErrors_dropsResource() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, true); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + + InOrder inOrder = inOrder(ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr initialUpdate = ldsUpdateCaptor.getValue(); + assertThat(initialUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(initialUpdate.getValue()); + Message invalidListener = mf.buildListenerWithApiListenerInvalid(LDS_RESOURCE); + call.sendResponse(LDS, Collections.singletonList(Any.pack(invalidListener)), VERSION_2, "0001"); + String expectedError = "LDS response Listener '" + LDS_RESOURCE + "' validation error"; + call.verifyRequestNack(LDS, LDS_RESOURCE, VERSION_1, "0001", NODE, + Collections.singletonList(expectedError)); + + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr finalUpdate = ldsUpdateCaptor.getValue(); + assertThat(finalUpdate.hasValue()).isFalse(); + assertThat(finalUpdate.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(finalUpdate.getStatus().getDescription()).contains(expectedError); + + 
BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + /** + * Tests that a NACKed LDS resource update is treated as an ambient error when + * fail_on_data_errors is disabled. + */ + @Test + public void ldsResourceNacked_withFailOnDataErrorsDisabled_isAmbientError() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, false); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + InOrder inOrder = inOrder(ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + + call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); + inOrder.verify(ldsResourceWatcher).onResourceChanged(any()); + Message invalidListener = mf.buildListenerWithApiListenerInvalid(LDS_RESOURCE); + call.sendResponse(LDS, Collections.singletonList(Any.pack(invalidListener)), VERSION_2, "0001"); + + String expectedError = "LDS response Listener '" + LDS_RESOURCE + "' validation error"; + call.verifyRequestNack(LDS, LDS_RESOURCE, VERSION_1, "0001", NODE, + Collections.singletonList(expectedError)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + inOrder.verify(ldsResourceWatcher).onAmbientError(statusCaptor.capture()); + Status receivedStatus = statusCaptor.getValue(); + assertThat(receivedStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(receivedStatus.getDescription()).contains(expectedError); + 
inOrder.verify(ldsResourceWatcher, never()).onResourceChanged(any()); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + @Test @SuppressWarnings("unchecked") public void multipleCdsWatchers() { String cdsResourceTwo = "cluster-bar.googleapis.com"; ResourceWatcher watcher1 = mock(ResourceWatcher.class); ResourceWatcher watcher2 = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, cdsResourceWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),cdsResourceTwo, watcher1); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),cdsResourceTwo, watcher2); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), cdsResourceTwo, watcher1); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), cdsResourceTwo, watcher2); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(CDS, Arrays.asList(CDS_RESOURCE, cdsResourceTwo), "", "", NODE); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); @@ -2673,9 +3484,12 @@ public void multipleCdsWatchers() { verifySubscribedResourcesMetadataSizes(0, 2, 0, 0); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); - verify(watcher1).onResourceDoesNotExist(cdsResourceTwo); - verify(watcher2).onResourceDoesNotExist(cdsResourceTwo); + verify(cdsResourceWatcher).onResourceChanged(argThat(statusOr -> + statusOr.getStatus().getDescription().contains(CDS_RESOURCE))); + verify(watcher1).onResourceChanged(argThat(statusOr -> + statusOr.getStatus().getDescription().contains(cdsResourceTwo))); + verify(watcher2).onResourceChanged(argThat(statusOr -> + statusOr.getStatus().getDescription().contains(cdsResourceTwo))); verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); verifyResourceMetadataDoesNotExist(CDS, cdsResourceTwo); 
verifySubscribedResourcesMetadataSizes(0, 2, 0, 0); @@ -2689,45 +3503,54 @@ public void multipleCdsWatchers() { Any.pack(mf.buildEdsCluster(cdsResourceTwo, edsService, "round_robin", null, null, true, null, "envoy.transport_sockets.tls", null, null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); - assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); - assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.LOGICAL_DNS); - assertThat(cdsUpdate.dnsHostName()).isEqualTo(dnsHostAddr + ":" + dnsHostPort); - LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + ArgumentCaptor> cdsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(cdsResourceWatcher, times(2)).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr capturedUpdate1 = cdsUpdateCaptor.getAllValues().get(1); + assertThat(capturedUpdate1.hasValue()).isTrue(); + CdsUpdate cdsUpdate1 = capturedUpdate1.getValue(); + assertThat(cdsUpdate1.clusterName()).isEqualTo(CDS_RESOURCE); + assertThat(cdsUpdate1.clusterType()).isEqualTo(ClusterType.LOGICAL_DNS); + assertThat(cdsUpdate1.dnsHostName()).isEqualTo(dnsHostAddr + ":" + dnsHostPort); + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate1.lbPolicyConfig()); assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); - assertThat(cdsUpdate.lrsServerInfo()).isNull(); - assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); - assertThat(cdsUpdate.upstreamTlsContext()).isNull(); - verify(watcher1).onChanged(cdsUpdateCaptor.capture()); - cdsUpdate = cdsUpdateCaptor.getValue(); - 
assertThat(cdsUpdate.clusterName()).isEqualTo(cdsResourceTwo); - assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); - assertThat(cdsUpdate.edsServiceName()).isEqualTo(edsService); - lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(cdsUpdate1.lrsServerInfo()).isNull(); + assertThat(cdsUpdate1.maxConcurrentRequests()).isNull(); + assertThat(cdsUpdate1.upstreamTlsContext()).isNull(); + ArgumentCaptor> watcher1Captor = ArgumentCaptor.forClass(StatusOr.class); + verify(watcher1, times(2)).onResourceChanged(watcher1Captor.capture()); + StatusOr capturedUpdate2 = watcher1Captor.getAllValues().get(1); + assertThat(capturedUpdate2.hasValue()).isTrue(); + CdsUpdate cdsUpdate2 = capturedUpdate2.getValue(); + assertThat(cdsUpdate2.clusterName()).isEqualTo(cdsResourceTwo); + assertThat(cdsUpdate2.clusterType()).isEqualTo(ClusterType.EDS); + assertThat(cdsUpdate2.edsServiceName()).isEqualTo(edsService); + lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate2.lbPolicyConfig()); assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); - assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(xdsServerInfo); - assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); - assertThat(cdsUpdate.upstreamTlsContext()).isNull(); - verify(watcher2).onChanged(cdsUpdateCaptor.capture()); - cdsUpdate = cdsUpdateCaptor.getValue(); - assertThat(cdsUpdate.clusterName()).isEqualTo(cdsResourceTwo); - assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); - assertThat(cdsUpdate.edsServiceName()).isEqualTo(edsService); - lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(cdsUpdate2.lrsServerInfo()).isEqualTo(xdsServerInfo); + 
assertThat(cdsUpdate2.maxConcurrentRequests()).isNull(); + assertThat(cdsUpdate2.upstreamTlsContext()).isNull(); + ArgumentCaptor> watcher2Captor = ArgumentCaptor.forClass(StatusOr.class); + verify(watcher2, times(2)).onResourceChanged(watcher2Captor.capture()); + StatusOr capturedUpdate3 = watcher2Captor.getAllValues().get(1); + assertThat(capturedUpdate3.hasValue()).isTrue(); + CdsUpdate cdsUpdate3 = capturedUpdate3.getValue(); + assertThat(cdsUpdate3.clusterName()).isEqualTo(cdsResourceTwo); + assertThat(cdsUpdate3.clusterType()).isEqualTo(ClusterType.EDS); + assertThat(cdsUpdate3.edsServiceName()).isEqualTo(edsService); + lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate3.lbPolicyConfig()); assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); - assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(xdsServerInfo); - assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); - assertThat(cdsUpdate.upstreamTlsContext()).isNull(); + assertThat(cdsUpdate3.lrsServerInfo()).isEqualTo(xdsServerInfo); + assertThat(cdsUpdate3.maxConcurrentRequests()).isNull(); + assertThat(cdsUpdate3.upstreamTlsContext()).isNull(); // Metadata of both clusters is stored. verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusters.get(0), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(CDS, cdsResourceTwo, clusters.get(1), VERSION_1, TIME_INCREMENT); @@ -2751,7 +3574,8 @@ public void edsResourceNotFound() { verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); // Server failed to return subscribed resource within expected time window. 
fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); + verify(edsResourceWatcher).onResourceChanged(argThat(statusOr -> + statusOr.getStatus().getDescription().contains(EDS_RESOURCE))); assertThat(fakeClock.getPendingTasks(EDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(EDS, EDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); @@ -2771,21 +3595,23 @@ public void edsCleanupNonceAfterUnsubscription() { List dropOverloads = ImmutableList.of(); List endpointsV1 = ImmutableList.of(lbEndpointHealthy); ImmutableMap resourcesV1 = ImmutableMap.of( - "A.1", Any.pack(mf.buildClusterLoadAssignment("A.1", endpointsV1, dropOverloads))); + "A.1", Any.pack(mf.buildClusterLoadAssignment("A.1", endpointsV1, dropOverloads))); call.sendResponse(EDS, resourcesV1.values().asList(), VERSION_1, "0000"); // {A.1} -> ACK, version 1 call.verifyRequest(EDS, "A.1", VERSION_1, "0000", NODE); - verify(edsResourceWatcher, times(1)).onChanged(any()); + verify(edsResourceWatcher, times(1)).onResourceChanged(any()); // trigger an EDS resource unsubscription. xdsClient.cancelXdsResourceWatch(XdsEndpointResource.getInstance(), "A.1", edsResourceWatcher); verifySubscribedResourcesMetadataSizes(0, 0, 0, 0); call.verifyRequest(EDS, Arrays.asList(), VERSION_1, "0000", NODE); + // The control plane can send an updated response for the empty subscription list, with a new + // nonce. 
+ call.sendResponse(EDS, Arrays.asList(), VERSION_1, "0001"); - // When re-subscribing, the version and nonce were properly forgotten, so the request is the - // same as the initial request + // When re-subscribing, the version was forgotten but not the nonce xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "A.1", edsResourceWatcher); - call.verifyRequest(EDS, "A.1", "", "", NODE, Mockito.timeout(2000).times(2)); + call.verifyRequest(EDS, "A.1", "", "0001", NODE, Mockito.timeout(2000)); } @Test @@ -2824,29 +3650,31 @@ public void edsResponseErrorHandling_someResourcesFailedUnpack() { verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); // The response is NACKed with the same error message. call.verifyRequestNack(EDS, EDS_RESOURCE, "", "0000", NODE, errors); - verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - EdsUpdate edsUpdate = edsUpdateCaptor.getValue(); + verify(edsResourceWatcher).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + EdsUpdate edsUpdate = statusOrUpdate.getValue(); assertThat(edsUpdate.clusterName).isEqualTo(EDS_RESOURCE); } /** * Tests a subscribed EDS resource transitioned to and from the invalid state. 
* - * @see - * A40-csds-support.md + * @see + * A40-csds-support.md */ @Test public void edsResponseErrorHandling_subscribedResourceInvalid() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"A", edsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"B", edsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"C", edsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "A", edsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "B", edsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "C", edsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(EDS, "A"); verifyResourceMetadataRequested(EDS, "B"); verifyResourceMetadataRequested(EDS, "C"); - verifySubscribedResourcesMetadataSizes(0, 0, 0, 3); // EDS -> {A, B, C}, version 1 List dropOverloads = ImmutableList.of(mf.buildDropOverload("lb", 200)); @@ -2857,6 +3685,7 @@ public void edsResponseErrorHandling_subscribedResourceInvalid() { "C", Any.pack(mf.buildClusterLoadAssignment("C", endpointsV1, dropOverloads))); call.sendResponse(EDS, resourcesV1.values().asList(), VERSION_1, "0000"); // {A, B, C} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), EDS.typeUrl()); verifyResourceMetadataAcked(EDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(EDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(EDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); @@ -2872,11 +3701,13 @@ public void edsResponseErrorHandling_subscribedResourceInvalid() { // {A} -> ACK, version 2 // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> ACK, version 1 + // Check metric data. 
+ verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), EDS.typeUrl()); List errorsV2 = ImmutableList.of("EDS response ClusterLoadAssignment 'B' validation error: "); verifyResourceMetadataAcked(EDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(EDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + VERSION_2, TIME_INCREMENT * 2, errorsV2, true); verifyResourceMetadataAcked(EDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); call.verifyRequestNack(EDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); @@ -2889,6 +3720,8 @@ public void edsResponseErrorHandling_subscribedResourceInvalid() { call.sendResponse(EDS, resourcesV3.values().asList(), VERSION_3, "0002"); // {A} -> ACK, version 2 // {B, C} -> ACK, version 3 + // Check metric data. + verifyResourceValidInvalidCount(1, 2, 0, xdsServerInfo.target(), EDS.typeUrl()); verifyResourceMetadataAcked(EDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataAcked(EDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); verifyResourceMetadataAcked(EDS, "C", resourcesV3.get("C"), VERSION_3, TIME_INCREMENT * 3); @@ -2904,8 +3737,10 @@ public void edsResourceFound() { // Client sent an ACK EDS request. 
call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); - verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - validateGoldenClusterLoadAssignment(edsUpdateCaptor.getValue()); + verify(edsResourceWatcher).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + validateGoldenClusterLoadAssignment(statusOrUpdate.getValue()); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); @@ -2919,8 +3754,10 @@ public void wrappedEdsResourceFound() { // Client sent an ACK EDS request. call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); - verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - validateGoldenClusterLoadAssignment(edsUpdateCaptor.getValue()); + verify(edsResourceWatcher).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + validateGoldenClusterLoadAssignment(statusOrUpdate.getValue()); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); @@ -2937,9 +3774,11 @@ public void cachedEdsResource_data() { call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); // Add another watcher. 
ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, watcher); - verify(watcher).onChanged(edsUpdateCaptor.capture()); - validateGoldenClusterLoadAssignment(edsUpdateCaptor.getValue()); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + validateGoldenClusterLoadAssignment(statusOrUpdate.getValue()); call.verifyNoMoreRequest(); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); @@ -2952,10 +3791,12 @@ public void cachedEdsResource_absent() { DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); + verify(edsResourceWatcher).onResourceChanged(argThat(statusOr -> + statusOr.getStatus().getDescription().contains(EDS_RESOURCE))); ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, watcher); - verify(watcher).onResourceDoesNotExist(EDS_RESOURCE); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(argThat(statusOr -> + statusOr.getStatus().getDescription().contains(EDS_RESOURCE))); call.verifyNoMoreRequest(); verifyResourceMetadataDoesNotExist(EDS, EDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); @@ -2986,7 +3827,7 @@ public void flowControlAbsent() throws Exception { fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); assertThat(fakeWatchClock.getPendingTasks().size()).isEqualTo(2); CyclicBarrier barrier = new 
CyclicBarrier(2); - doAnswer(blockUpdate(barrier)).when(cdsResourceWatcher).onChanged(any(CdsUpdate.class)); + doAnswer(blockUpdate(barrier)).when(cdsResourceWatcher).onResourceChanged(any()); CountDownLatch latch = new CountDownLatch(1); new Thread(() -> { @@ -3008,16 +3849,120 @@ public void flowControlAbsent() throws Exception { verifyResourceMetadataAcked( CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); barrier.await(); - verify(cdsResourceWatcher, atLeastOnce()).onChanged(any()); + verify(cdsResourceWatcher, atLeastOnce()).onResourceChanged(any()); String errorMsg = "CDS response Cluster 'cluster.googleapis.com2' validation error: " + "Cluster cluster.googleapis.com2: unspecified cluster discovery type"; call.verifyRequestNack(CDS, Arrays.asList(CDS_RESOURCE, anotherCdsResource), VERSION_1, "0001", NODE, Arrays.asList(errorMsg)); barrier.await(); latch.await(10, TimeUnit.SECONDS); - verify(cdsResourceWatcher, times(2)).onChanged(any()); - verify(anotherWatcher).onResourceDoesNotExist(eq(anotherCdsResource)); - verify(anotherWatcher).onError(any()); + verify(cdsResourceWatcher, times(2)).onResourceChanged(any()); + verify(anotherWatcher, times(2)).onResourceChanged( + argThat(statusOr -> statusOr.getStatus().getDescription().contains(anotherCdsResource))); + } + + @Test + public void resourceTimerIsTransientError_schedulesExtendedTimeout() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + ServerInfo serverInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, + false, true, true, false); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(serverInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of()) + .build(); + xdsClient = new XdsClientImpl( + 
xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + @SuppressWarnings("unchecked") + ResourceWatcher watcher = mock(ResourceWatcher.class); + String resourceName = "cluster.googleapis.com"; + + xdsClient.watchXdsResource( + XdsClusterResource.getInstance(), + resourceName, + watcher, + fakeClock.getScheduledExecutorService()); + + ScheduledTask task = Iterables.getOnlyElement( + fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)); + assertThat(task.getDelay(TimeUnit.SECONDS)) + .isEqualTo(XdsClientImpl.EXTENDED_RESOURCE_FETCH_TIMEOUT_SEC); + fakeClock.runDueTasks(); + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + @Test + public void resourceTimerIsTransientError_callsOnErrorUnavailable() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, ignoreResourceDeletion(), + true, true, false); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "authority.xds.com", + AuthorityInfo.create( + "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_CUSTOM_AUTHORITY, CHANNEL_CREDENTIALS))), + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of("cert-instance-name", + CertificateProviderInfo.create("file-watcher", ImmutableMap.of()))) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + 
fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + String timeoutResource = CDS_RESOURCE + "_timeout"; + @SuppressWarnings("unchecked") + ResourceWatcher timeoutWatcher = mock(ResourceWatcher.class); + + xdsClient.watchXdsResource( + XdsClusterResource.getInstance(), + timeoutResource, + timeoutWatcher, + fakeClock.getScheduledExecutorService()); + + assertThat(resourceDiscoveryCalls).hasSize(1); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + call.verifyRequest(CDS, ImmutableList.of(timeoutResource), "", "", NODE); + fakeClock.forwardTime(XdsClientImpl.EXTENDED_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.runDueTasks(); + @SuppressWarnings("unchecked") + ArgumentCaptor> statusOrCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(timeoutWatcher).onResourceChanged(statusOrCaptor.capture()); + StatusOr statusOr = statusOrCaptor.getValue(); + Status error = statusOr.getStatus(); + assertThat(error.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(error.getDescription()).isEqualTo( + "Timed out waiting for resource " + timeoutResource + " from xDS server"); + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; } private Answer blockUpdate(CyclicBarrier barrier) { @@ -3048,7 +3993,7 @@ public void simpleFlowControl() throws Exception { // Updated EDS response. 
Any updatedClusterLoadAssignment = Any.pack(mf.buildClusterLoadAssignment(EDS_RESOURCE, ImmutableList.of(mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", - mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3), 2, 0)), + mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3, "endpoint-host-name"), 2, 0)), ImmutableList.of())); call.sendResponse(EDS, updatedClusterLoadAssignment, VERSION_2, "0001"); // message not processed due to flow control @@ -3056,7 +4001,7 @@ public void simpleFlowControl() throws Exception { assertThat(call.isReady()).isFalse(); CyclicBarrier barrier = new CyclicBarrier(2); - doAnswer(blockUpdate(barrier)).when(edsResourceWatcher).onChanged(any(EdsUpdate.class)); + doAnswer(blockUpdate(barrier)).when(edsResourceWatcher).onResourceChanged(any()); CountDownLatch latch = new CountDownLatch(1); new Thread(() -> { @@ -3071,12 +4016,14 @@ public void simpleFlowControl() throws Exception { verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); barrier.await(); - verify(edsResourceWatcher, atLeastOnce()).onChanged(edsUpdateCaptor.capture()); - EdsUpdate edsUpdate = edsUpdateCaptor.getAllValues().get(0); + verify(edsResourceWatcher, atLeastOnce()).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getAllValues().get(0); + assertThat(statusOrUpdate.hasValue()).isTrue(); + EdsUpdate edsUpdate = statusOrUpdate.getValue(); validateGoldenClusterLoadAssignment(edsUpdate); barrier.await(); latch.await(10, TimeUnit.SECONDS); - verify(edsResourceWatcher, times(2)).onChanged(any()); + verify(edsResourceWatcher, times(2)).onResourceChanged(any()); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, updatedClusterLoadAssignment, VERSION_2, TIME_INCREMENT * 2); } @@ -3088,7 +4035,7 @@ public void flowControlUnknownType() { call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); call.sendResponse(EDS, testClusterLoadAssignment, VERSION_1, "0000"); 
call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); - verify(edsResourceWatcher).onChanged(any()); + verify(edsResourceWatcher).onResourceChanged(any()); } @Test @@ -3100,8 +4047,10 @@ public void edsResourceUpdated() { // Initial EDS response. call.sendResponse(EDS, testClusterLoadAssignment, VERSION_1, "0000"); call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); - verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - EdsUpdate edsUpdate = edsUpdateCaptor.getValue(); + verify(edsResourceWatcher).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + EdsUpdate edsUpdate = statusOrUpdate.getValue(); validateGoldenClusterLoadAssignment(edsUpdate); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); @@ -3109,12 +4058,14 @@ public void edsResourceUpdated() { // Updated EDS response. Any updatedClusterLoadAssignment = Any.pack(mf.buildClusterLoadAssignment(EDS_RESOURCE, ImmutableList.of(mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", - mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3), 2, 0)), + mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3, "endpoint-host-name"), 2, 0)), ImmutableList.of())); call.sendResponse(EDS, updatedClusterLoadAssignment, VERSION_2, "0001"); - verify(edsResourceWatcher, times(2)).onChanged(edsUpdateCaptor.capture()); - edsUpdate = edsUpdateCaptor.getValue(); + verify(edsResourceWatcher, times(2)).onResourceChanged(edsUpdateCaptor.capture()); + statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + edsUpdate = statusOrUpdate.getValue(); assertThat(edsUpdate.clusterName).isEqualTo(EDS_RESOURCE); assertThat(edsUpdate.dropPolicies).isEmpty(); assertThat(edsUpdate.localityLbEndpointsMap) @@ -3122,7 +4073,9 @@ public void edsResourceUpdated() { Locality.create("region2", "zone2", "subzone2"), 
LocalityLbEndpoints.create( ImmutableList.of( - LbEndpoint.create("172.44.2.2", 8000, 3, true)), 2, 0)); + LbEndpoint.create("172.44.2.2", 8000, 3, + true, "endpoint-host-name", ImmutableMap.of())), + 2, 0, ImmutableMap.of())); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, updatedClusterLoadAssignment, VERSION_2, TIME_INCREMENT * 2); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); @@ -3138,9 +4091,9 @@ public void edsDuplicateLocalityInTheSamePriority() { Any updatedClusterLoadAssignment = Any.pack(mf.buildClusterLoadAssignment(EDS_RESOURCE, ImmutableList.of( mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", - mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3), 2, 1), + mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3, "endpoint-host-name"), 2, 1), mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", - mf.buildLbEndpoint("172.44.2.3", 8080, "healthy", 10), 2, 1) + mf.buildLbEndpoint("172.44.2.3", 8080, "healthy", 10, "endpoint-host-name"), 2, 1) ), ImmutableList.of())); call.sendResponse(EDS, updatedClusterLoadAssignment, "0", "0001"); @@ -3150,6 +4103,12 @@ public void edsDuplicateLocalityInTheSamePriority() { + "locality:Locality{region=region2, zone=zone2, subZone=subzone2} for priority:1"; call.verifyRequestNack(EDS, EDS_RESOURCE, "", "0001", NODE, ImmutableList.of( errorMsg)); + @SuppressWarnings("unchecked") + ArgumentCaptor> captor = ArgumentCaptor.forClass(StatusOr.class); + verify(edsResourceWatcher).onResourceChanged(captor.capture()); + StatusOr statusOrUpdate = captor.getValue(); + assertThat(statusOrUpdate.hasValue()).isFalse(); + assertThat(statusOrUpdate.getStatus().getDescription()).contains(errorMsg); } @Test @@ -3158,10 +4117,10 @@ public void edsResourceDeletedByCds() { String resource = "backend-service.googleapis.com"; ResourceWatcher cdsWatcher = mock(ResourceWatcher.class); ResourceWatcher edsWatcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),resource, cdsWatcher); - 
xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),resource, edsWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, cdsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), resource, cdsWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), resource, edsWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); verifyResourceMetadataRequested(CDS, resource); verifyResourceMetadataRequested(EDS, EDS_RESOURCE); @@ -3176,12 +4135,13 @@ public void edsResourceDeletedByCds() { Any.pack(mf.buildEdsCluster(CDS_RESOURCE, EDS_RESOURCE, "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null, null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); - verify(cdsWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + ArgumentCaptor> cdsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(cdsWatcher, times(1)).onResourceChanged(cdsUpdateCaptor.capture()); + CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue().getValue(); assertThat(cdsUpdate.edsServiceName()).isEqualTo(null); assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(xdsServerInfo); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher, times(1)).onResourceChanged(cdsUpdateCaptor.capture()); + cdsUpdate = cdsUpdateCaptor.getValue().getValue(); assertThat(cdsUpdate.edsServiceName()).isEqualTo(EDS_RESOURCE); assertThat(cdsUpdate.lrsServerInfo()).isNull(); verifyResourceMetadataAcked(CDS, resource, clusters.get(0), VERSION_1, TIME_INCREMENT); @@ -3201,13 +4161,15 @@ public void 
edsResourceDeletedByCds() { mf.buildClusterLoadAssignment(resource, ImmutableList.of( mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", - mf.buildLbEndpoint("192.168.0.2", 9090, "healthy", 3), 1, 0)), + mf.buildLbEndpoint("192.168.0.2", 9090, "healthy", 3, + "endpoint-host-name"), 1, 0)), ImmutableList.of(mf.buildDropOverload("lb", 100))))); call.sendResponse(EDS, clusterLoadAssignments, VERSION_1, "0000"); - verify(edsWatcher).onChanged(edsUpdateCaptor.capture()); - assertThat(edsUpdateCaptor.getValue().clusterName).isEqualTo(resource); - verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - assertThat(edsUpdateCaptor.getValue().clusterName).isEqualTo(EDS_RESOURCE); + ArgumentCaptor> edsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(edsWatcher, times(1)).onResourceChanged(edsUpdateCaptor.capture()); + assertThat(edsUpdateCaptor.getValue().getValue().clusterName).isEqualTo(resource); + verify(edsResourceWatcher, times(1)).onResourceChanged(edsUpdateCaptor.capture()); + assertThat(edsUpdateCaptor.getValue().getValue().clusterName).isEqualTo(EDS_RESOURCE); verifyResourceMetadataAcked( EDS, EDS_RESOURCE, clusterLoadAssignments.get(0), VERSION_1, TIME_INCREMENT * 2); @@ -3225,12 +4187,8 @@ public void edsResourceDeletedByCds() { "envoy.transport_sockets.tls", null, null ))); call.sendResponse(CDS, clusters, VERSION_2, "0001"); - verify(cdsResourceWatcher, times(2)).onChanged(cdsUpdateCaptor.capture()); - assertThat(cdsUpdateCaptor.getValue().edsServiceName()).isNull(); - // Note that the endpoint must be deleted even if the ignore_resource_deletion feature. - // This happens because the cluster CDS_RESOURCE is getting replaced, and not deleted. 
- verify(edsResourceWatcher, never()).onResourceDoesNotExist(EDS_RESOURCE); - verify(edsResourceWatcher, never()).onResourceDoesNotExist(resource); + verify(cdsResourceWatcher, times(2)).onResourceChanged(cdsUpdateCaptor.capture()); + assertThat(cdsUpdateCaptor.getValue().getValue().edsServiceName()).isNull(); verifyNoMoreInteractions(cdsWatcher, edsWatcher); verifyResourceMetadataAcked( EDS, EDS_RESOURCE, clusterLoadAssignments.get(0), VERSION_1, TIME_INCREMENT * 2); @@ -3238,7 +4196,6 @@ public void edsResourceDeletedByCds() { EDS, resource, clusterLoadAssignments.get(1), VERSION_1, TIME_INCREMENT * 2); // no change verifyResourceMetadataAcked(CDS, resource, clusters.get(0), VERSION_2, TIME_INCREMENT * 3); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusters.get(1), VERSION_2, TIME_INCREMENT * 3); - verifySubscribedResourcesMetadataSizes(0, 2, 0, 2); } @Test @@ -3247,9 +4204,9 @@ public void multipleEdsWatchers() { String edsResourceTwo = "cluster-load-assignment-bar.googleapis.com"; ResourceWatcher watcher1 = mock(ResourceWatcher.class); ResourceWatcher watcher2 = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, edsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),edsResourceTwo, watcher1); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),edsResourceTwo, watcher2); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), edsResourceTwo, watcher1); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), edsResourceTwo, watcher2); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(EDS, Arrays.asList(EDS_RESOURCE, edsResourceTwo), "", "", NODE); verifyResourceMetadataRequested(EDS, EDS_RESOURCE); @@ -3257,16 +4214,24 @@ public void multipleEdsWatchers() { verifySubscribedResourcesMetadataSizes(0, 0, 0, 2); 
fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); - verify(watcher1).onResourceDoesNotExist(edsResourceTwo); - verify(watcher2).onResourceDoesNotExist(edsResourceTwo); + verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getDescription().contains(EDS_RESOURCE))); + verify(watcher1).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getDescription().contains(edsResourceTwo))); + verify(watcher2).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getDescription().contains(edsResourceTwo))); verifyResourceMetadataDoesNotExist(EDS, EDS_RESOURCE); verifyResourceMetadataDoesNotExist(EDS, edsResourceTwo); verifySubscribedResourcesMetadataSizes(0, 0, 0, 2); call.sendResponse(EDS, testClusterLoadAssignment, VERSION_1, "0000"); - verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - EdsUpdate edsUpdate = edsUpdateCaptor.getValue(); + verify(edsResourceWatcher, times(2)).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + EdsUpdate edsUpdate = statusOrUpdate.getValue(); validateGoldenClusterLoadAssignment(edsUpdate); verifyNoMoreInteractions(watcher1, watcher2); verifyResourceMetadataAcked( @@ -3278,12 +4243,15 @@ public void multipleEdsWatchers() { mf.buildClusterLoadAssignment(edsResourceTwo, ImmutableList.of( mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", - mf.buildLbEndpoint("172.44.2.2", 8000, "healthy", 3), 2, 0)), + mf.buildLbEndpoint("172.44.2.2", 8000, "healthy", 3, "endpoint-host-name"), + 2, 0)), ImmutableList.of())); call.sendResponse(EDS, clusterLoadAssignmentTwo, VERSION_2, "0001"); - verify(watcher1).onChanged(edsUpdateCaptor.capture()); - edsUpdate = edsUpdateCaptor.getValue(); + 
verify(watcher1, times(2)).onResourceChanged(edsUpdateCaptor.capture()); + statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + edsUpdate = statusOrUpdate.getValue(); assertThat(edsUpdate.clusterName).isEqualTo(edsResourceTwo); assertThat(edsUpdate.dropPolicies).isEmpty(); assertThat(edsUpdate.localityLbEndpointsMap) @@ -3291,9 +4259,13 @@ public void multipleEdsWatchers() { Locality.create("region2", "zone2", "subzone2"), LocalityLbEndpoints.create( ImmutableList.of( - LbEndpoint.create("172.44.2.2", 8000, 3, true)), 2, 0)); - verify(watcher2).onChanged(edsUpdateCaptor.capture()); - edsUpdate = edsUpdateCaptor.getValue(); + LbEndpoint.create("172.44.2.2", 8000, 3, + true, "endpoint-host-name", ImmutableMap.of())), + 2, 0, ImmutableMap.of())); + verify(watcher2, times(2)).onResourceChanged(edsUpdateCaptor.capture()); + statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + edsUpdate = statusOrUpdate.getValue(); assertThat(edsUpdate.clusterName).isEqualTo(edsResourceTwo); assertThat(edsUpdate.dropPolicies).isEmpty(); assertThat(edsUpdate.localityLbEndpointsMap) @@ -3301,7 +4273,9 @@ public void multipleEdsWatchers() { Locality.create("region2", "zone2", "subzone2"), LocalityLbEndpoints.create( ImmutableList.of( - LbEndpoint.create("172.44.2.2", 8000, 3, true)), 2, 0)); + LbEndpoint.create("172.44.2.2", 8000, 3, + true, "endpoint-host-name", ImmutableMap.of())), + 2, 0, ImmutableMap.of())); verifyNoMoreInteractions(edsResourceWatcher); verifyResourceMetadataAcked( EDS, edsResourceTwo, clusterLoadAssignmentTwo, VERSION_2, TIME_INCREMENT * 2); @@ -3322,43 +4296,113 @@ public void useIndependentRpcContext() { // The inbound RPC finishes and closes its context. The outbound RPC's control plane RPC // should not be impacted. 
cancellableContext.close(); - verify(ldsResourceWatcher, never()).onError(any(Status.class)); + verify(ldsResourceWatcher, never()).onAmbientError(any(Status.class)); + verify(ldsResourceWatcher, never()).onResourceChanged(argThat( + statusOr -> !statusOr.hasValue() + )); call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); - verify(ldsResourceWatcher).onChanged(any(LdsUpdate.class)); + verify(ldsResourceWatcher).onResourceChanged(any()); } finally { cancellableContext.detach(prevContext); } } + @Test + public void streamClosedWithNoResponse() { + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, + rdsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, true, xdsServerInfo.target()); + // Management server closes the RPC stream before sending any response. + call.sendCompleted(); + // Check metric data. 
+ callback_ReportServerConnection(); + verifyServerConnection(1, false, xdsServerInfo.target()); + verify(ldsResourceWatcher, Mockito.timeout(1000)).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr ldsStatusOr = ldsUpdateCaptor.getValue(); + assertThat(ldsStatusOr.hasValue()).isFalse(); + verifyStatusWithNodeId(ldsStatusOr.getStatus(), Code.UNAVAILABLE, + "ADS stream closed with OK before receiving a response"); + verify(rdsResourceWatcher, Mockito.timeout(1000)).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr rdsStatusOr = rdsUpdateCaptor.getValue(); + assertThat(rdsStatusOr.hasValue()).isFalse(); + verifyStatusWithNodeId(rdsStatusOr.getStatus(), Code.UNAVAILABLE, + "ADS stream closed with OK before receiving a response"); + } + + @Test + public void streamClosedAfterSendingResponses() { + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, + rdsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, true, xdsServerInfo.target()); + ScheduledTask ldsResourceTimeout = + Iterables.getOnlyElement(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)); + ScheduledTask rdsResourceTimeout = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)); + call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(2, true, xdsServerInfo.target()); + assertThat(ldsResourceTimeout.isCancelled()).isTrue(); + call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); + assertThat(rdsResourceTimeout.isCancelled()).isTrue(); + // Management server closes the RPC stream after sending responses. + call.sendCompleted(); + // Check metric data. 
+ callback_ReportServerConnection(); + verifyServerConnection(3, true, xdsServerInfo.target()); + verify(ldsResourceWatcher, never()).onAmbientError(any(Status.class)); + verify(rdsResourceWatcher, never()).onAmbientError(any(Status.class)); + verify(ldsResourceWatcher, times(1)).onResourceChanged(any()); + verify(rdsResourceWatcher, times(1)).onResourceChanged(any()); + } + @Test public void streamClosedAndRetryWithBackoff() { - InOrder inOrder = Mockito.inOrder(backoffPolicyProvider, backoffPolicy1, backoffPolicy2); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, ldsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, + InOrder inOrder = inOrder(backoffPolicyProvider, backoffPolicy1, backoffPolicy2); + InOrder ldsWatcherInOrder = inOrder(ldsResourceWatcher); + InOrder rdsWatcherInOrder = inOrder(rdsResourceWatcher); + InOrder cdsWatcherInOrder = inOrder(cdsResourceWatcher); + InOrder edsWatcherInOrder = inOrder(edsResourceWatcher); + when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2, backoffPolicy2); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, cdsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(LDS, LDS_RESOURCE, "", "", NODE); call.verifyRequest(RDS, RDS_RESOURCE, "", "", NODE); call.verifyRequest(CDS, CDS_RESOURCE, "", "", NODE); call.verifyRequest(EDS, EDS_RESOURCE, "", "", NODE); - // Management 
server closes the RPC stream with an error. + // Management server closes the RPC stream with an error. No response received yet. + fakeClock.forwardNanos(1000L); // Make sure retry isn't based on stopwatch 0 call.sendError(Status.UNKNOWN.asException()); - verify(ldsResourceWatcher, Mockito.timeout(1000).times(1)) - .onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNKNOWN, ""); - verify(rdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNKNOWN, ""); - verify(cdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNKNOWN, ""); - verify(edsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNKNOWN, ""); + ldsWatcherInOrder.verify(ldsResourceWatcher, timeout(1000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + rdsWatcherInOrder.verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + cdsWatcherInOrder.verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + edsWatcherInOrder.verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + + verifyServerFailureCount(1, 1, xdsServerInfo.target()); // Retry after backoff. - inOrder.verify(backoffPolicyProvider).get(); inOrder.verify(backoffPolicy1).nextBackoffNanos(); ScheduledTask retryTask = Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); @@ -3370,17 +4414,23 @@ public void streamClosedAndRetryWithBackoff() { call.verifyRequest(CDS, CDS_RESOURCE, "", "", NODE); call.verifyRequest(EDS, EDS_RESOURCE, "", "", NODE); - // Management server becomes unreachable. 
+ // Management server becomes unreachable. No response received on this stream either. String errorMsg = "my fault"; call.sendError(Status.UNAVAILABLE.withDescription(errorMsg).asException()); - verify(ldsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); - verify(rdsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); - verify(cdsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); - verify(edsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + ldsWatcherInOrder.verify(ldsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + rdsWatcherInOrder.verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + cdsWatcherInOrder.verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + edsWatcherInOrder.verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + + verifyServerFailureCount(2, 1, xdsServerInfo.target()); // Retry after backoff. 
inOrder.verify(backoffPolicy1).nextBackoffNanos(); @@ -3399,41 +4449,49 @@ public void streamClosedAndRetryWithBackoff() { mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(2))))); call.sendResponse(LDS, listeners, "63", "3242"); call.verifyRequest(LDS, LDS_RESOURCE, "63", "3242", NODE); + ldsWatcherInOrder.verify(ldsResourceWatcher).onResourceChanged( + argThat(statusOr -> statusOr.hasValue())); List routeConfigs = ImmutableList.of( Any.pack(mf.buildRouteConfiguration(RDS_RESOURCE, mf.buildOpaqueVirtualHosts(2)))); call.sendResponse(RDS, routeConfigs, "5", "6764"); call.verifyRequest(RDS, RDS_RESOURCE, "5", "6764", NODE); + rdsWatcherInOrder.verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> statusOr.hasValue())); + // Stream fails AFTER a response. Error is suppressed and no watcher notification occurs. call.sendError(Status.DEADLINE_EXCEEDED.asException()); - verify(ldsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verify(rdsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verify(cdsResourceWatcher, times(3)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.DEADLINE_EXCEEDED, ""); - verify(edsResourceWatcher, times(3)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.DEADLINE_EXCEEDED, ""); + + // Failure count does NOT increase. + verifyServerFailureCount(2, 1, xdsServerInfo.target()); // Reset backoff sequence and retry after backoff. 
inOrder.verify(backoffPolicyProvider).get(); inOrder.verify(backoffPolicy2).nextBackoffNanos(); retryTask = Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); - assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(20L); - fakeClock.forwardNanos(20L); + fakeClock.forwardNanos(retryTask.getDelay(TimeUnit.NANOSECONDS)); call = resourceDiscoveryCalls.poll(); call.verifyRequest(LDS, LDS_RESOURCE, "63", "", NODE); call.verifyRequest(RDS, RDS_RESOURCE, "5", "", NODE); call.verifyRequest(CDS, CDS_RESOURCE, "", "", NODE); call.verifyRequest(EDS, EDS_RESOURCE, "", "", NODE); - // Management server becomes unreachable again. + // Management server becomes unreachable again. This is on a new stream, so error propagates. call.sendError(Status.UNAVAILABLE.asException()); - verify(ldsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verify(rdsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verify(cdsResourceWatcher, times(4)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); - verify(edsResourceWatcher, times(4)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); + ldsWatcherInOrder.verify(ldsResourceWatcher).onAmbientError( + argThat(status -> status.getCode() == Code.UNAVAILABLE)); + rdsWatcherInOrder.verify(rdsResourceWatcher).onAmbientError( + argThat(status -> status.getCode() == Code.UNAVAILABLE)); + cdsWatcherInOrder.verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + edsWatcherInOrder.verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + + // Failure count is now 3. + verifyServerFailureCount(3, 1, xdsServerInfo.target()); // Retry after backoff. 
inOrder.verify(backoffPolicy2).nextBackoffNanos(); @@ -3447,7 +4505,41 @@ public void streamClosedAndRetryWithBackoff() { call.verifyRequest(CDS, CDS_RESOURCE, "", "", NODE); call.verifyRequest(EDS, EDS_RESOURCE, "", "", NODE); - inOrder.verifyNoMoreInteractions(); + // Send a response so CPC is considered working and close gracefully. + call.sendResponse(LDS, listeners, "63", "3242"); + call.sendCompleted(); + + // Final failure count is still 3. + verifyServerFailureCount(3, 1, xdsServerInfo.target()); + } + + @Test + public void newWatcher_receivesCachedDataAndAmbientError() throws Exception { + InOrder inOrder = inOrder(ldsResourceWatcher); + DiscoveryRpcCall call1 = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + call1.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); + inOrder.verify(ldsResourceWatcher, timeout(5000)) + .onResourceChanged(argThat(statusOr -> statusOr.hasValue())); + + call1.sendError(Status.DEADLINE_EXCEEDED.asException()); + ScheduledTask retryTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); + fakeClock.forwardNanos(retryTask.getDelay(TimeUnit.NANOSECONDS)); + DiscoveryRpcCall call2 = resourceDiscoveryCalls.poll(); + Status propagatedError = Status.UNAVAILABLE.withDescription("real failure"); + call2.sendError(propagatedError.asException()); + inOrder.verify(ldsResourceWatcher, timeout(5000)).onAmbientError( + argThat(status -> status.getCode() == Code.UNAVAILABLE)); + @SuppressWarnings("unchecked") + ResourceWatcher ldsResourceWatcher2 = mock(ResourceWatcher.class); + xdsClient.watchXdsResource( + XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher2); + + verify(ldsResourceWatcher2, timeout(5000)).onResourceChanged( + argThat(statusOr -> statusOr.hasValue())); + verify(ldsResourceWatcher2, timeout(5000)).onAmbientError( + argThat(status -> status.getCode() == Code.UNAVAILABLE)); } @Test @@ -3457,16 +4549,23 @@ public void 
streamClosedAndRetryRaceWithAddRemoveWatchers() { xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, true, xdsServerInfo.target()); call.sendError(Status.UNAVAILABLE.asException()); verify(ldsResourceWatcher, Mockito.timeout(1000).times(1)) - .onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); - verify(rdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); + .onResourceChanged(ldsUpdateCaptor.capture()); + verifyStatusWithNodeId(ldsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, ""); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + verifyStatusWithNodeId(rdsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, ""); ScheduledTask retryTask = Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(10L); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, false, xdsServerInfo.target()); + xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); xdsClient.cancelXdsResourceWatch(XdsRouteConfigureResource.getInstance(), @@ -3481,11 +4580,19 @@ public void streamClosedAndRetryRaceWithAddRemoveWatchers() { call.verifyRequest(EDS, EDS_RESOURCE, "", "", NODE); call.verifyNoMoreRequest(); + // Check metric data. 
+ callback_ReportServerConnection(); + verifyServerConnection(2,false, xdsServerInfo.target()); + call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); List routeConfigs = ImmutableList.of( Any.pack(mf.buildRouteConfiguration(RDS_RESOURCE, mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); call.sendResponse(RDS, routeConfigs, VERSION_1, "0000"); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(2, true, xdsServerInfo.target()); + verifyNoMoreInteractions(ldsResourceWatcher, rdsResourceWatcher); } @@ -3497,6 +4604,9 @@ public void streamClosedAndRetryRestartsResourceInitialFetchTimerForUnresolvedRe xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, true, xdsServerInfo.target()); ScheduledTask ldsResourceTimeout = Iterables.getOnlyElement(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)); ScheduledTask rdsResourceTimeout = @@ -3507,19 +4617,46 @@ public void streamClosedAndRetryRestartsResourceInitialFetchTimerForUnresolvedRe Iterables.getOnlyElement(fakeClock.getPendingTasks(EDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)); call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); assertThat(ldsResourceTimeout.isCancelled()).isTrue(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(2, true, xdsServerInfo.target()); call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); assertThat(rdsResourceTimeout.isCancelled()).isTrue(); + // Check metric data. 
+ callback_ReportServerConnection(); + verifyServerConnection(3, true, xdsServerInfo.target()); call.sendError(Status.UNAVAILABLE.asException()); assertThat(cdsResourceTimeout.isCancelled()).isTrue(); assertThat(edsResourceTimeout.isCancelled()).isTrue(); - verify(ldsResourceWatcher, never()).onError(errorCaptor.capture()); - verify(rdsResourceWatcher, never()).onError(errorCaptor.capture()); - verify(cdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); - verify(edsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); + + // With the reverted logic, the first error is suppressed because a response was received. + // We verify that no error callbacks are invoked at this point. + verify(ldsResourceWatcher, never()).onAmbientError(any(Status.class)); + verify(rdsResourceWatcher, never()).onAmbientError(any(Status.class)); + + // The metric report for a failed server connection is also suppressed. + callback_ReportServerConnection(); + verifyServerConnection(4, true, xdsServerInfo.target()); + + fakeClock.forwardTime(5, TimeUnit.SECONDS); + DiscoveryRpcCall call2 = resourceDiscoveryCalls.poll(); + call2.sendError(Status.UNAVAILABLE.asException()); + + // Now, verify the watchers are notified as expected. 
+ verify(ldsResourceWatcher).onAmbientError(any(Status.class)); + verify(rdsResourceWatcher).onAmbientError(any(Status.class)); + verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + + fakeClock.forwardTime(5, TimeUnit.SECONDS); + DiscoveryRpcCall call3 = resourceDiscoveryCalls.poll(); + assertThat(call3).isNotNull(); fakeClock.forwardNanos(10L); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).hasSize(0); @@ -3538,13 +4675,13 @@ public void reportLoadStatsToServer() { lrsCall.sendResponse(Collections.singletonList(clusterName), 1000L); fakeClock.forwardNanos(1000L); - lrsCall.verifyNextReportClusters(Collections.singletonList(new String[] {clusterName, null})); + lrsCall.verifyNextReportClusters(Collections.singletonList(new String[]{clusterName, null})); dropStats.release(); fakeClock.forwardNanos(1000L); // In case of having unreported cluster stats, one last report will be sent after corresponding // stats object released. - lrsCall.verifyNextReportClusters(Collections.singletonList(new String[] {clusterName, null})); + lrsCall.verifyNextReportClusters(Collections.singletonList(new String[]{clusterName, null})); fakeClock.forwardNanos(1000L); // Currently load reporting continues (with empty stats) even if all stats objects have been @@ -3573,8 +4710,10 @@ public void serverSideListenerFound() { call.sendResponse(LDS, listeners, "0", "0000"); // Client sends an ACK LDS request. 
call.verifyRequest(LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - EnvoyServerProtoData.Listener parsedListener = ldsUpdateCaptor.getValue().listener(); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + EnvoyServerProtoData.Listener parsedListener = statusOrUpdate.getValue().listener(); assertThat(parsedListener.name()).isEqualTo(LISTENER_RESOURCE); assertThat(parsedListener.address()).isEqualTo("0.0.0.0:7000"); assertThat(parsedListener.defaultFilterChain()).isNull(); @@ -3611,25 +4750,26 @@ public void serverSideListenerNotFound() { verifyNoInteractions(ldsResourceWatcher); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(ldsResourceWatcher).onResourceDoesNotExist(LISTENER_RESOURCE); + verify(ldsResourceWatcher).onResourceChanged(argThat( + statusOr -> statusOr.getStatus().getDescription().contains(LISTENER_RESOURCE))); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); } @Test public void serverSideListenerResponseErrorHandling_badDownstreamTlsContext() { GrpcXdsClientImplTestBase.DiscoveryRpcCall call = - startResourceWatcher(XdsListenerResource.getInstance(), LISTENER_RESOURCE, - ldsResourceWatcher); + startResourceWatcher(XdsListenerResource.getInstance(), LISTENER_RESOURCE, + ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( - "route-foo.googleapis.com", null, + "route-foo.googleapis.com", null, Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( - null, null,false); + null, null, false); Message filterChain = mf.buildFilterChain( - Collections.emptyList(), downstreamTlsContext, "envoy.transport_sockets.tls", + 
Collections.emptyList(), downstreamTlsContext, "envoy.transport_sockets.tls", hcmFilter); Message listener = - mf.buildListenerWithFilterChain(LISTENER_RESOURCE, 7000, "0.0.0.0", filterChain); + mf.buildListenerWithFilterChain(LISTENER_RESOURCE, 7000, "0.0.0.0", filterChain); List listeners = ImmutableList.of(Any.pack(listener)); call.sendResponse(LDS, listeners, "0", "0000"); // The response NACKed with errors indicating indices of the failed resources. @@ -3637,8 +4777,10 @@ public void serverSideListenerResponseErrorHandling_badDownstreamTlsContext() { + "0.0.0.0:7000\' validation error: " + "common-tls-context is required in downstream-tls-context"; call.verifyRequestNack(LDS, LISTENER_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); - verify(ldsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isFalse(); + verifyStatusWithNodeId(statusOrUpdate.getStatus(), Code.UNAVAILABLE, errorMsg); } @Test @@ -3650,7 +4792,7 @@ public void serverSideListenerResponseErrorHandling_badTransportSocketName() { "route-foo.googleapis.com", null, Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( - "cert1", "cert2",false); + "cert1", "cert2", false); Message filterChain = mf.buildFilterChain( Collections.emptyList(), downstreamTlsContext, "envoy.transport_sockets.bad1", hcmFilter); @@ -3664,8 +4806,8 @@ public void serverSideListenerResponseErrorHandling_badTransportSocketName() { + "transport-socket with name envoy.transport_sockets.bad1 not supported."; call.verifyRequestNack(LDS, LISTENER_RESOURCE, "", "0000", NODE, ImmutableList.of( errorMsg)); - verify(ldsResourceWatcher).onError(errorCaptor.capture()); - 
verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + verifyStatusWithNodeId(ldsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, errorMsg); } @Test @@ -3681,6 +4823,9 @@ public void sendingToStoppedServer() throws Exception { xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); fakeClock.forwardTime(14, TimeUnit.SECONDS); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, false, xdsServerInfo.target()); // Restart the server xdsServer = cleanupRule.register( @@ -3692,13 +4837,21 @@ public void sendingToStoppedServer() throws Exception { .build() .start()); fakeClock.forwardTime(5, TimeUnit.SECONDS); - verify(ldsResourceWatcher, never()).onResourceDoesNotExist(LDS_RESOURCE); + verify(ldsResourceWatcher, never()).onResourceChanged(argThat( + statusOr -> statusOr.getStatus().getDescription().contains(LDS_RESOURCE))); fakeClock.forwardTime(20, TimeUnit.SECONDS); // Trigger rpcRetryTimer DiscoveryRpcCall call = resourceDiscoveryCalls.poll(3, TimeUnit.SECONDS); + // Check metric data. + callback_ReportServerConnection(); if (call == null) { // The first rpcRetry may have happened before the channel was ready fakeClock.forwardTime(50, TimeUnit.SECONDS); call = resourceDiscoveryCalls.poll(3, TimeUnit.SECONDS); } + verifyServerConnection(2, false, xdsServerInfo.target()); + + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(3, false, xdsServerInfo.target()); // NOTE: There is a ScheduledExecutorService that may get involved due to the reconnect // so you cannot rely on the logic being single threaded. 
The timeout() in verifyRequest @@ -3706,11 +4859,18 @@ public void sendingToStoppedServer() throws Exception { // Send a response and do verifications call.sendResponse(LDS, mf.buildWrappedResource(testListenerVhosts), VERSION_1, "0001"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0001", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + @SuppressWarnings("unchecked") + ArgumentCaptor> captor = ArgumentCaptor.forClass(StatusOr.class); + verify(ldsResourceWatcher, timeout(1000).atLeast(2)).onResourceChanged(captor.capture()); + StatusOr lastValue = captor.getAllValues().get(captor.getAllValues().size() - 1); + assertThat(lastValue.hasValue()).isTrue(); + verifyGoldenListenerVhosts(lastValue.getValue()); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 1, 0, 0); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, true, xdsServerInfo.target()); } catch (Throwable t) { throw t; // This allows putting a breakpoint here for debugging } @@ -3719,14 +4879,38 @@ public void sendingToStoppedServer() throws Exception { @Test public void sendToBadUrl() throws Exception { // Setup xdsClient to fail on stream creation - XdsClientImpl client = createXdsClient("some. garbage"); + String garbageUri = "some. 
garbage"; + XdsClientImpl client = createXdsClient(garbageUri); client.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); fakeClock.forwardTime(20, TimeUnit.SECONDS); - verify(ldsResourceWatcher, Mockito.timeout(5000).times(1)).onError(ArgumentMatchers.any()); + verify(ldsResourceWatcher, Mockito.timeout(5000).atLeastOnce()) + .onResourceChanged(ldsUpdateCaptor.capture()); + assertThat(ldsUpdateCaptor.getValue().getStatus().getDescription()).contains(garbageUri); client.shutdown(); } + @Test + public void circuitBreakingConversionOf32bitIntTo64bitLongForMaxRequestNegativeValue() { + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + Any clusterCircuitBreakers = Any.pack( + mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, null, false, null, + "envoy.transport_sockets.tls", mf.buildCircuitBreakers(50, -1), null)); + call.sendResponse(CDS, clusterCircuitBreakers, VERSION_1, "0000"); + + // Client sent an ACK CDS request. + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); + + assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); + assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); + assertThat(cdsUpdate.maxConcurrentRequests()).isEqualTo(4294967295L); + } + @Test public void sendToNonexistentServer() throws Exception { // Setup xdsClient to fail on stream creation @@ -3735,30 +4919,228 @@ public void sendToNonexistentServer() throws Exception { // file. 
Assume localhost doesn't speak HTTP/2 on the finger port XdsClientImpl client = createXdsClient("localhost:79"); client.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); - verify(ldsResourceWatcher, Mockito.timeout(5000).times(1)).onError(ArgumentMatchers.any()); + verify(ldsResourceWatcher, Mockito.timeout(5000)).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isFalse(); + assertThat(statusOrUpdate.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); assertThat(fakeClock.numPendingTasks()).isEqualTo(1); //retry assertThat(fakeClock.getPendingTasks().iterator().next().toString().contains("RpcRetryTask")) .isTrue(); client.shutdown(); } + @Test + public void validAndInvalidResourceMetricReport() { + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "A", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "A.1", edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "B", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "B.1", edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "C", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "C.1", edsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + assertThat(call).isNotNull(); + + // CDS -> {A, B, C}, version 1 + ImmutableMap resourcesV1 = ImmutableMap.of( + "A", Any.pack(mf.buildEdsCluster("A", "A.1", "round_robin", null, null, false, null, + "envoy.transport_sockets.tls", null, null + )), + "B", Any.pack(mf.buildEdsCluster("B", "B.1", "round_robin", null, null, false, null, + "envoy.transport_sockets.tls", null, null + )), + "C", Any.pack(mf.buildEdsCluster("C", "C.1", "round_robin", null, null, false, null, + "envoy.transport_sockets.tls", null, null + ))); + call.sendResponse(CDS, 
resourcesV1.values().asList(), VERSION_1, "0000"); + // {A, B, C} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), CDS.typeUrl()); + + // EDS -> {A.1, B.1, C.1}, version 1 + List dropOverloads = ImmutableList.of(); + List endpointsV1 = ImmutableList.of(lbEndpointHealthy); + ImmutableMap resourcesV11 = ImmutableMap.of( + "A.1", Any.pack(mf.buildClusterLoadAssignment("A.1", endpointsV1, dropOverloads)), + "B.1", Any.pack(mf.buildClusterLoadAssignment("B.1", endpointsV1, dropOverloads)), + "C.1", Any.pack(mf.buildClusterLoadAssignment("C.1", endpointsV1, dropOverloads))); + call.sendResponse(EDS, resourcesV11.values().asList(), VERSION_1, "0000"); + // {A.1, B.1, C.1} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), EDS.typeUrl()); + + // CDS -> {A, B}, version 2 + // Failed to parse endpoint B + ImmutableMap resourcesV2 = ImmutableMap.of( + "A", Any.pack(mf.buildEdsCluster("A", "A.2", "round_robin", null, null, false, null, + "envoy.transport_sockets.tls", null, null + )), + "B", Any.pack(mf.buildClusterInvalid("B"))); + call.sendResponse(CDS, resourcesV2.values().asList(), VERSION_2, "0001"); + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {C} -> does not exist + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), CDS.typeUrl()); + } + + @Test + public void serverFailureMetricReport() { + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, + rdsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + // Management server closes the RPC stream before sending any response. 
+ call.sendCompleted(); + verify(ldsResourceWatcher, Mockito.timeout(1000)).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr ldsStatusOr = ldsUpdateCaptor.getValue(); + assertThat(ldsStatusOr.hasValue()).isFalse(); + verifyStatusWithNodeId(ldsStatusOr.getStatus(), Code.UNAVAILABLE, + "ADS stream closed with OK before receiving a response"); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr rdsStatusOr = rdsUpdateCaptor.getValue(); + assertThat(rdsStatusOr.hasValue()).isFalse(); + verifyStatusWithNodeId(rdsStatusOr.getStatus(), Code.UNAVAILABLE, + "ADS stream closed with OK before receiving a response"); + verifyServerFailureCount(1, 1, xdsServerInfo.target()); + } + + @Test + public void serverFailureMetricReport_forRetryAndBackoff() { + InOrder inOrder = inOrder(backoffPolicyProvider, backoffPolicy1, backoffPolicy2); + InOrder ldsWatcherInOrder = inOrder(ldsResourceWatcher); + InOrder rdsWatcherInOrder = inOrder(rdsResourceWatcher); + InOrder cdsWatcherInOrder = inOrder(cdsResourceWatcher); + InOrder edsWatcherInOrder = inOrder(edsResourceWatcher); + when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2, backoffPolicy2); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, + rdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + + // Management server closes the RPC stream with an error. 
+ call.sendError(Status.UNKNOWN.asException()); + ldsWatcherInOrder.verify(ldsResourceWatcher, timeout(1000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + rdsWatcherInOrder.verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + cdsWatcherInOrder.verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + edsWatcherInOrder.verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + verifyServerFailureCount(1, 1, xdsServerInfo.target()); + + // Retry after backoff. + inOrder.verify(backoffPolicyProvider).get(); + inOrder.verify(backoffPolicy1).nextBackoffNanos(); + ScheduledTask retryTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); + assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(10L); + fakeClock.forwardNanos(10L); + call = resourceDiscoveryCalls.poll(); + + // Management server becomes unreachable. 
+ String errorMsg = "my fault"; + call.sendError(Status.UNAVAILABLE.withDescription(errorMsg).asException()); + ldsWatcherInOrder.verify(ldsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + rdsWatcherInOrder.verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + cdsWatcherInOrder.verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + edsWatcherInOrder.verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + verifyServerFailureCount(2, 1, xdsServerInfo.target()); + + // Retry after backoff. + inOrder.verify(backoffPolicy1).nextBackoffNanos(); + retryTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); + assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(100L); + fakeClock.forwardNanos(100L); + call = resourceDiscoveryCalls.poll(); + + List resources = ImmutableList.of(FAILING_ANY, testListenerRds, FAILING_ANY); + call.sendResponse(LDS, resources, "63", "3242"); + ldsWatcherInOrder.verify(ldsResourceWatcher).onResourceChanged( + argThat(statusOr -> statusOr.hasValue())); + + List routeConfigs = ImmutableList.of(FAILING_ANY, testRouteConfig, FAILING_ANY); + call.sendResponse(RDS, routeConfigs, "5", "6764"); + rdsWatcherInOrder.verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> statusOr.hasValue())); + + // Stream fails AFTER a response. Error is suppressed and no watcher notification occurs. + call.sendError(Status.DEADLINE_EXCEEDED.asException()); + + // Failure count does NOT increase because the error was suppressed. It is still 2. 
+ verifyServerFailureCount(2, 1, xdsServerInfo.target()); + + // Reset backoff sequence and retry after backoff. + inOrder.verify(backoffPolicyProvider).get(); + inOrder.verify(backoffPolicy2).nextBackoffNanos(); + retryTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); + assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(20L); + fakeClock.forwardNanos(20L); + call = resourceDiscoveryCalls.poll(); + + // Management server becomes unreachable again. This is on a new stream, so error propagates. + call.sendError(Status.UNAVAILABLE.asException()); + ldsWatcherInOrder.verify(ldsResourceWatcher).onAmbientError( + argThat(status -> status.getCode() == Code.UNAVAILABLE)); + rdsWatcherInOrder.verify(rdsResourceWatcher).onAmbientError( + argThat(status -> status.getCode() == Code.UNAVAILABLE)); + cdsWatcherInOrder.verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + edsWatcherInOrder.verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + + // Server failure count is now 3. + verifyServerFailureCount(3, 1, xdsServerInfo.target()); + + // Retry after backoff. + inOrder.verify(backoffPolicy2).nextBackoffNanos(); + retryTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); + assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(200L); + fakeClock.forwardNanos(200L); + call = resourceDiscoveryCalls.poll(); + + List clusters = ImmutableList.of(FAILING_ANY, testClusterRoundRobin); + call.sendResponse(CDS, clusters, VERSION_1, "0000"); + call.sendCompleted(); + + // Final failure count is still 3 as the stream closed gracefully. 
+ verifyServerFailureCount(3, 1, xdsServerInfo.target()); + } + private XdsClientImpl createXdsClient(String serverUri) { BootstrapInfo bootstrapInfo = buildBootStrap(serverUri); return new XdsClientImpl( - DEFAULT_XDS_TRANSPORT_FACTORY, + new GrpcXdsTransportFactory(null), bootstrapInfo, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, fakeClock.getStopwatchSupplier(), timeProvider, MessagePrinter.INSTANCE, - new TlsContextManagerImpl(bootstrapInfo)); + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); } - private BootstrapInfo buildBootStrap(String serverUri) { + private BootstrapInfo buildBootStrap(String serverUri) { ServerInfo xdsServerInfo = ServerInfo.create(serverUri, CHANNEL_CREDENTIALS, - ignoreResourceDeletion()); + ignoreResourceDeletion(), true, false, false); return Bootstrapper.BootstrapInfo.builder() .servers(Collections.singletonList(xdsServerInfo)) @@ -3768,7 +5150,7 @@ private BootstrapInfo buildBootStrap(String serverUri) { AuthorityInfo.create( "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/%s", ImmutableList.of(Bootstrapper.ServerInfo.create( - SERVER_URI_CUSTOME_AUTHORITY, CHANNEL_CREDENTIALS))), + SERVER_URI_CUSTOM_AUTHORITY, CHANNEL_CREDENTIALS))), "", AuthorityInfo.create( "xdstp:///envoy.config.listener.v3.Listener/%s", @@ -3862,7 +5244,7 @@ protected void sendResponse( } protected void sendResponse(XdsResourceType type, Any resource, String versionInfo, - String nonce) { + String nonce) { sendResponse(type, ImmutableList.of(resource), versionInfo, nonce); } @@ -3894,6 +5276,7 @@ protected void sendResponse(List clusters, long loadReportIntervalNano) } protected abstract static class MessageFactory { + /** Throws {@link InvalidProtocolBufferException} on {@link Any#unpack(Class)}. 
*/ protected static final Any FAILING_ANY = Any.newBuilder().setTypeUrl("fake").build(); @@ -3977,7 +5360,7 @@ protected Message buildLocalityLbEndpoints(String region, String zone, String su } protected abstract Message buildLbEndpoint(String address, int port, String healthStatus, - int lbWeight); + int lbWeight, String endpointHostname); protected abstract Message buildDropOverload(String category, int dropPerMillion); @@ -3993,4 +5376,70 @@ protected abstract Message buildHttpConnectionManagerFilter( protected abstract Message buildTerminalFilter(); } + + private static class XdsStringResource extends XdsResourceType { + @Override + @SuppressWarnings("unchecked") + protected Class unpackedClassName() { + return StringValue.class; + } + + @Override + public String typeName() { + return "EMPTY"; + } + + @Override + public String typeUrl() { + return "type.googleapis.com/google.protobuf.StringValue"; + } + + @Override + public boolean shouldRetrieveResourceKeysForArgs() { + return false; + } + + @Override + protected boolean isFullStateOfTheWorld() { + return false; + } + + @Override + @Nullable + protected String extractResourceName(Message unpackedResource) { + if (!(unpackedResource instanceof StringValue)) { + return null; + } + return ((StringValue) unpackedResource).getValue(); + } + + @Override + protected StringUpdate doParse(Args args, Message unpackedMessage) + throws ResourceInvalidException { + return new StringUpdate(((StringValue) unpackedMessage).getValue()); + } + } + + private static final class StringUpdate implements ResourceUpdate { + @SuppressWarnings("UnusedVariable") + public final String value; + + public StringUpdate(String value) { + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof StringUpdate)) { + return false; + } + StringUpdate that = (StringUpdate) o; + return Objects.equals(this.value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(value); + } + } } diff 
--git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplV3Test.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplV3Test.java index 40a9bff514f..3966fae7f20 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplV3Test.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplV3Test.java @@ -613,18 +613,15 @@ protected Message buildLeastRequestLbConfig(int choiceCount) { } @Override - @SuppressWarnings("deprecation") protected Message buildUpstreamTlsContext(String instanceName, String certName) { CommonTlsContext.Builder commonTlsContextBuilder = CommonTlsContext.newBuilder(); if (instanceName != null && certName != null) { - CommonTlsContext.CertificateProviderInstance providerInstance = - CommonTlsContext.CertificateProviderInstance.newBuilder() - .setInstanceName(instanceName) - .setCertificateName(certName) - .build(); CommonTlsContext.CombinedCertificateValidationContext combined = CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance(providerInstance) + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance(CertificateProviderPluginInstance.newBuilder() + .setInstanceName(instanceName) + .setCertificateName(certName))) .build(); commonTlsContextBuilder.setCombinedValidationContext(combined); } @@ -705,7 +702,7 @@ protected Message buildLocalityLbEndpoints(String region, String zone, String su @Override protected Message buildLbEndpoint(String address, int port, String healthStatus, - int lbWeight) { + int lbWeight, String endpointHostname) { HealthStatus status; switch (healthStatus) { case "unknown": @@ -733,7 +730,8 @@ protected Message buildLbEndpoint(String address, int port, String healthStatus, .setEndpoint( Endpoint.newBuilder().setAddress( Address.newBuilder().setSocketAddress( - SocketAddress.newBuilder().setAddress(address).setPortValue(port)))) + SocketAddress.newBuilder().setAddress(address).setPortValue(port))) + 
.setHostname(endpointHostname)) .setHealthStatus(status) .setLoadBalancingWeight(UInt32Value.of(lbWeight)) .build(); @@ -750,7 +748,6 @@ protected Message buildDropOverload(String category, int dropPerMillion) { .build(); } - @SuppressWarnings("deprecation") @Override protected FilterChain buildFilterChain( List alpn, Message tlsContext, String transportSocketName, diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java index 703e429fa23..9c606a962f6 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java @@ -30,6 +30,7 @@ import io.grpc.Server; import io.grpc.Status; import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.XdsTransportFactory; import java.util.concurrent.BlockingQueue; @@ -37,6 +38,7 @@ import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -44,6 +46,8 @@ @RunWith(JUnit4.class) public class GrpcXdsTransportFactoryTest { + @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); + private Server server; @Before @@ -92,9 +96,10 @@ public void onCompleted() { @Test public void callApis() throws Exception { XdsTransportFactory.XdsTransport xdsTransport = - GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.create( - Bootstrapper.ServerInfo.create("localhost:" + server.getPort(), - InsecureChannelCredentials.create())); + new GrpcXdsTransportFactory(null) + .create( + Bootstrapper.ServerInfo.create( + "localhost:" + server.getPort(), InsecureChannelCredentials.create())); MethodDescriptor methodDescriptor = AggregatedDiscoveryServiceGrpc.getStreamAggregatedResourcesMethod(); XdsTransportFactory.StreamingCall streamingCall = @@ 
-117,6 +122,59 @@ public void callApis() throws Exception { xdsTransport.shutdown(); } + @Test + public void refCountedXdsTransport_sameXdsServerAddress_returnsExistingTransport() { + Bootstrapper.ServerInfo xdsServerInfo = + Bootstrapper.ServerInfo.create( + "localhost:" + server.getPort(), InsecureChannelCredentials.create()); + GrpcXdsTransportFactory xdsTransportFactory = new GrpcXdsTransportFactory(null); + // Calling create() for the first time creates a new GrpcXdsTransport instance. + // The ref count was previously 0 and now is 1. + XdsTransportFactory.XdsTransport transport1 = xdsTransportFactory.create(xdsServerInfo); + // Calling create() for the second time to the same xDS server address returns the same + // GrpcXdsTransport instance. The ref count was previously 1 and now is 2. + XdsTransportFactory.XdsTransport transport2 = xdsTransportFactory.create(xdsServerInfo); + assertThat(transport1).isSameInstanceAs(transport2); + // Calling shutdown() for the first time does not shut down the GrpcXdsTransport instance. + // The ref count was previously 2 and now is 1. + transport1.shutdown(); + // Calling shutdown() for the second time shuts down the GrpcXdsTransport instance. + // The ref count was previously 1 and now is 0. + transport2.shutdown(); + } + + @Test + public void refCountedXdsTransport_differentXdsServerAddress_returnsDifferentTransport() + throws Exception { + // Create and start a second xDS server on a different port. 
+ Server server2 = + grpcCleanupRule.register( + Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(echoAdsService()) + .build() + .start()); + Bootstrapper.ServerInfo xdsServerInfo1 = + Bootstrapper.ServerInfo.create( + "localhost:" + server.getPort(), InsecureChannelCredentials.create()); + Bootstrapper.ServerInfo xdsServerInfo2 = + Bootstrapper.ServerInfo.create( + "localhost:" + server2.getPort(), InsecureChannelCredentials.create()); + GrpcXdsTransportFactory xdsTransportFactory = new GrpcXdsTransportFactory(null); + // Calling create() to the first xDS server creates a new GrpcXdsTransport instance. + // The ref count was previously 0 and now is 1. + XdsTransportFactory.XdsTransport transport1 = xdsTransportFactory.create(xdsServerInfo1); + // Calling create() to the second xDS server creates a different GrpcXdsTransport instance. + // The ref count was previously 0 and now is 1. + XdsTransportFactory.XdsTransport transport2 = xdsTransportFactory.create(xdsServerInfo2); + assertThat(transport1).isNotSameInstanceAs(transport2); + // Calling shutdown() shuts down the GrpcXdsTransport instance for the first xDS server. + // The ref count was previously 1 and now is 0. + transport1.shutdown(); + // Calling shutdown() shuts down the GrpcXdsTransport instance for the second xDS server. + // The ref count was previously 1 and now is 0. 
+ transport2.shutdown(); + } + private static class FakeEventHandler implements XdsTransportFactory.EventHandler { private final BlockingQueue respQ = new LinkedBlockingQueue<>(); diff --git a/xds/src/test/java/io/grpc/xds/LazyLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/LazyLoadBalancerTest.java new file mode 100644 index 00000000000..c79d048c9d3 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/LazyLoadBalancerTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; + +import io.grpc.CallOptions; +import io.grpc.ConnectivityState; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.SynchronizationContext; +import io.grpc.internal.PickSubchannelArgsImpl; +import io.grpc.testing.TestMethodDescriptors; +import java.util.Arrays; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit test for {@link io.grpc.xds.LazyLoadBalancer}. 
*/ +@RunWith(JUnit4.class) +public final class LazyLoadBalancerTest { + private SynchronizationContext syncContext = + new SynchronizationContext((t, e) -> { + throw new AssertionError(e); + }); + private LoadBalancer.PickSubchannelArgs args = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), + new Metadata(), + CallOptions.DEFAULT, + new LoadBalancer.PickDetailsConsumer() {}); + private FakeHelper helper = new FakeHelper(); + + @Test + public void pickerIsNoopAfterEarlyShutdown() { + LazyLoadBalancer lb = new LazyLoadBalancer(helper, new LoadBalancer.Factory() { + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + throw new AssertionError("unexpected"); + } + }); + lb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(Arrays.asList()) + .build()); + SubchannelPicker picker = helper.picker; + assertThat(picker).isNotNull(); + lb.shutdown(); + + picker.pickSubchannel(args); + } + + class FakeHelper extends LoadBalancer.Helper { + ConnectivityState state; + SubchannelPicker picker; + + @Override + public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + this.state = newState; + this.picker = newPicker; + } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } + + @Override + public String getAuthority() { + return "localhost"; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java index 659bacd3626..302faed95a4 100644 --- a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java @@ -22,6 +22,7 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import 
static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.LoadBalancerMatchers.pickerReturns; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -50,6 +51,7 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.FixedResultPicker; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -62,7 +64,6 @@ import io.grpc.internal.PickFirstLoadBalancerProvider; import io.grpc.util.AbstractTestHelper; import io.grpc.util.MultiChildLoadBalancer.ChildLbState; -import io.grpc.xds.LeastRequestLoadBalancer.EmptyPicker; import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestLbState; import io.grpc.xds.LeastRequestLoadBalancer.ReadyPicker; @@ -186,8 +187,7 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { Subchannel removedSubchannel = getSubchannel(removedEag); Subchannel oldSubchannel = getSubchannel(oldEag1); SubchannelStateListener removedListener = - testHelperInstance.getSubchannelStateListeners() - .get(testHelperInstance.getRealForMockSubChannel(removedSubchannel)); + testHelperInstance.getSubchannelStateListener(removedSubchannel); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -201,8 +201,6 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { verify(removedSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).requestConnection(); - assertThat(getChildEags(loadBalancer)).containsExactly(removedEag, oldEag1); - // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); @@ -219,8 +217,6 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { removedListener.onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); 
deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(getChildEags(loadBalancer)).containsExactly(oldEag2, newEag); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); @@ -233,28 +229,6 @@ private Subchannel getSubchannel(EquivalentAddressGroup removedEag) { return subchannels.get(Collections.singletonList(removedEag)); } - private Subchannel getSubchannel(ChildLbState childLbState) { - return subchannels.get(Collections.singletonList(childLbState.getEag())); - } - - private static List getChildEags(LeastRequestLoadBalancer loadBalancer) { - return loadBalancer.getChildLbStates().stream() - .map(ChildLbState::getEag) - // .map(EquivalentAddressGroup::getAddresses) - .collect(Collectors.toList()); - } - - private List getSubchannels(LeastRequestLoadBalancer lb) { - return lb.getChildLbStates().stream() - .map(this::getSubchannel) - .collect(Collectors.toList()); - } - - private LeastRequestLbState getChildLbState(PickResult pickResult) { - EquivalentAddressGroup eag = pickResult.getSubchannel().getAddresses(); - return (LeastRequestLbState) loadBalancer.getChildLbState(eag); - } - @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(helper); @@ -263,9 +237,10 @@ public void pickAfterStateChange() throws Exception { .build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); - Subchannel subchannel = getSubchannel(childLbState); + Subchannel subchannel = getSubchannel(servers.get(0)); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING); deliverSubchannelState(subchannel, 
ConnectivityStateInfo.forNonError(READY)); @@ -278,7 +253,8 @@ public void pickAfterStateChange() throws Exception { assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); assertThat(childLbState.getCurrentPicker().toString()).contains(error.toString()); refreshInvokedAndUpdateBS(inOrder, CONNECTING); - assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs)) + .isEqualTo(PickResult.withNoResult()); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(helper).refreshNameResolution(); @@ -329,10 +305,12 @@ public void ignoreShutdownSubchannelStateChange() { ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); + List savedSubchannels = new ArrayList<>(subchannels.values()); loadBalancer.shutdown(); - for (Subchannel sc : getSubchannels(loadBalancer)) { + for (Subchannel sc : savedSubchannels) { verify(sc).shutdown(); // When the subchannel is being shut down, a SHUTDOWN connectivity state is delivered // back to the subchannel state listener. @@ -350,11 +328,14 @@ public void stayTransientFailureUntilReady() { .build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); // Simulate state transitions for each subchannel individually. 
- for (ChildLbState childLbState : loadBalancer.getChildLbStates()) { - Subchannel sc = getSubchannel(childLbState); + List children = new ArrayList<>(loadBalancer.getChildLbStates()); + for (int i = 0; i < children.size(); i++) { + ChildLbState childLbState = children.get(i); + Subchannel sc = getSubchannel(servers.get(i)); Status error = Status.UNKNOWN.withDescription("connection broken"); deliverSubchannelState(sc, ConnectivityStateInfo.forTransientFailure(error)); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); @@ -369,7 +350,7 @@ public void stayTransientFailureUntilReady() { inOrder.verifyNoMoreInteractions(); ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); - Subchannel subchannel = getSubchannel(childLbState); + Subchannel subchannel = getSubchannel(servers.get(0)); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); assertThat(childLbState.getCurrentState()).isEqualTo(READY); inOrder.verify(helper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); @@ -408,10 +389,11 @@ public void refreshNameResolutionWhenSubchannelConnectionBroken() { assertThat(addressesAcceptanceStatus.isOk()).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); // Simulate state transitions for each subchannel individually. 
- for (Subchannel sc : getSubchannels(loadBalancer)) { + for (Subchannel sc : subchannels.values()) { verify(sc).requestConnection(); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); Status error = Status.UNKNOWN.withDescription("connection broken"); @@ -423,7 +405,8 @@ public void refreshNameResolutionWhenSubchannelConnectionBroken() { deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(helper).refreshNameResolution(); verify(sc, times(2)).requestConnection(); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); } AbstractTestHelper.verifyNoMoreMeaningfulInteractions(helper); @@ -449,8 +432,8 @@ public void pickerLeastRequest() throws Exception { ((LeastRequestLbState) childLbStates.get(i)).getActiveRequests()); } - for (ChildLbState cs : childLbStates) { - deliverSubchannelState(getSubchannel(cs), ConnectivityStateInfo.forNonError(READY)); + for (Subchannel sc : subchannels.values()) { + deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(READY)); } // Capture the active ReadyPicker once all subchannels are READY @@ -460,45 +443,37 @@ public void pickerLeastRequest() throws Exception { ReadyPicker picker = (ReadyPicker) pickerCaptor.getValue(); - assertThat(picker.getChildEags()) - .containsExactlyElementsIn(childLbStates.stream().map(ChildLbState::getEag).toArray()); + assertThat(picker.getChildPickers()).containsExactlyElementsIn( + childLbStates.stream().map(ChildLbState::getCurrentPicker).toArray()); // Make random return 0, then 2 for the sample indexes. 
when(mockRandom.nextInt(childLbStates.size())).thenReturn(0, 2); PickResult pickResult1 = picker.pickSubchannel(mockArgs); verify(mockRandom, times(choiceCount)).nextInt(childLbStates.size()); - assertEquals(childLbStates.get(0), getChildLbState(pickResult1)); + assertThat(pickResult1.getSubchannel()).isEqualTo(getSubchannel(servers.get(0))); // This simulates sending the actual RPC on the picked channel ClientStreamTracer streamTracer1 = pickResult1.getStreamTracerFactory() .newClientStreamTracer(StreamInfo.newBuilder().build(), new Metadata()); streamTracer1.streamCreated(Attributes.EMPTY, new Metadata()); - assertEquals(1, getChildLbState(pickResult1).getActiveRequests()); + assertEquals(1, ((LeastRequestLbState) childLbStates.get(0)).getActiveRequests()); // For the second pick it should pick the one with lower inFlight. when(mockRandom.nextInt(childLbStates.size())).thenReturn(0, 2); PickResult pickResult2 = picker.pickSubchannel(mockArgs); // Since this is the second pick we expect the total random samples to be choiceCount * 2 verify(mockRandom, times(choiceCount * 2)).nextInt(childLbStates.size()); - assertEquals(childLbStates.get(2), getChildLbState(pickResult2)); + assertThat(pickResult2.getSubchannel()).isEqualTo(getSubchannel(servers.get(2))); // For the third pick we unavoidably pick subchannel with index 1. 
when(mockRandom.nextInt(childLbStates.size())).thenReturn(1, 1); PickResult pickResult3 = picker.pickSubchannel(mockArgs); verify(mockRandom, times(choiceCount * 3)).nextInt(childLbStates.size()); - assertEquals(childLbStates.get(1), getChildLbState(pickResult3)); + assertThat(pickResult3.getSubchannel()).isEqualTo(getSubchannel(servers.get(1))); // Finally ensure a finished RPC decreases inFlight streamTracer1.streamClosed(Status.OK); - assertEquals(0, getChildLbState(pickResult1).getActiveRequests()); - } - - @Test - public void pickerEmptyList() throws Exception { - SubchannelPicker picker = new EmptyPicker(); - - assertNull(picker.pickSubchannel(mockArgs).getSubchannel()); - assertEquals(Status.OK, picker.pickSubchannel(mockArgs).getStatus()); + assertEquals(0, ((LeastRequestLbState) childLbStates.get(0)).getActiveRequests()); } @Test @@ -578,7 +553,8 @@ public void subchannelStateIsolation() throws Exception { Iterator pickers = pickerCaptor.getAllValues().iterator(); // The picker is incrementally updated as subchannels become READY assertEquals(CONNECTING, stateIterator.next()); - assertThat(pickers.next()).isInstanceOf(EmptyPicker.class); + assertThat(pickers.next().pickSubchannel(mockArgs)) + .isEqualTo(PickResult.withNoResult()); assertEquals(READY, stateIterator.next()); assertThat(getList(pickers.next())).containsExactly(sc1); assertEquals(READY, stateIterator.next()); @@ -609,8 +585,8 @@ public void readyPicker_emptyList() { @Test public void internalPickerComparisons() { - EmptyPicker empty1 = new EmptyPicker(); - EmptyPicker empty2 = new EmptyPicker(); + FixedResultPicker empty1 = new FixedResultPicker(PickResult.withNoResult()); + FixedResultPicker empty2 = new FixedResultPicker(PickResult.withNoResult()); loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); @@ -648,8 +624,8 @@ public void emptyAddresses() { private List getList(SubchannelPicker picker) { if (picker instanceof 
ReadyPicker) { - return ((ReadyPicker) picker).getChildEags().stream() - .map(this::getSubchannel) + return ((ReadyPicker) picker).getChildPickers().stream() + .map((p) -> p.pickSubchannel(mockArgs).getSubchannel()) .collect(Collectors.toList()); } else { return Collections.emptyList(); diff --git a/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java b/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java index e09066461c4..b8b20248026 100644 --- a/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java @@ -101,6 +101,22 @@ public class LoadBalancerConfigFactoryTest { .build())) .build()) .build(); + + private static final Policy WRR_POLICY_WITH_METRICS = Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setName("backend") + .setTypedConfig( + Any.pack(ClientSideWeightedRoundRobin.newBuilder() + .setBlackoutPeriod(Duration.newBuilder().setSeconds(287).build()) + .setEnableOobLoadReport( + BoolValue.newBuilder().setValue(true).build()) + .setErrorUtilizationPenalty( + FloatValue.newBuilder().setValue(1.75F).build()) + .addMetricNamesForComputingUtilization("foo") + .addMetricNamesForComputingUtilization("bar") + .build())) + .build()) + .build(); private static final String CUSTOM_POLICY_NAME = "myorg.MyCustomLeastRequestPolicy"; private static final String CUSTOM_POLICY_FIELD_KEY = "choiceCount"; private static final double CUSTOM_POLICY_FIELD_VALUE = 2; @@ -130,6 +146,15 @@ public class LoadBalancerConfigFactoryTest { ImmutableMap.of("weighted_round_robin", ImmutableMap.of("blackoutPeriod","287s", "enableOobLoadReport", true, "errorUtilizationPenalty", 1.75F ))))); + + private static final LbConfig VALID_WRR_CONFIG_WITH_METRICS = + new LbConfig("wrr_locality_experimental", + ImmutableMap.of("childPolicy", + ImmutableList.of(ImmutableMap.of("weighted_round_robin", + ImmutableMap.of("blackoutPeriod", "287s", "enableOobLoadReport", 
true, + "errorUtilizationPenalty", 1.75F, + LoadBalancerConfigFactory.METRIC_NAMES_FOR_COMPUTING_UTILIZATION, + ImmutableList.of("foo", "bar")))))); private static final LbConfig VALID_RING_HASH_CONFIG = new LbConfig("ring_hash_experimental", ImmutableMap.of("minRingSize", (double) RING_HASH_MIN_RING_SIZE, "maxRingSize", (double) RING_HASH_MAX_RING_SIZE)); @@ -165,6 +190,13 @@ public void weightedRoundRobin() throws ResourceInvalidException { assertThat(newLbConfig(cluster, true)).isEqualTo(VALID_WRR_CONFIG); } + @Test + public void weightedRoundRobin_withMetrics() throws ResourceInvalidException { + Cluster cluster = newCluster(buildWrrPolicy(WRR_POLICY_WITH_METRICS)); + + assertThat(newLbConfig(cluster, true)).isEqualTo(VALID_WRR_CONFIG_WITH_METRICS); + } + @Test public void weightedRoundRobin_invalid() throws ResourceInvalidException { Cluster cluster = newCluster(buildWrrPolicy(Policy.newBuilder() diff --git a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java index c11a3a6e0d2..9bdf86132b6 100644 --- a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java @@ -178,11 +178,15 @@ public void cancelled(Context context) { when(backoffPolicy2.nextBackoffNanos()) .thenReturn(TimeUnit.SECONDS.toNanos(2L), TimeUnit.SECONDS.toNanos(20L)); addFakeStatsData(); - lrsClient = new LoadReportClient(loadStatsManager, - GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.createForTest(channel), - NODE, - syncContext, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, - fakeClock.getStopwatchSupplier()); + lrsClient = + new LoadReportClient( + loadStatsManager, + new GrpcXdsTransportFactory(null).createForTest(channel), + NODE, + syncContext, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier()); syncContext.execute(new Runnable() { @Override public void run() { diff --git 
a/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java b/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java index ecc0112a2e0..0499bafdb23 100644 --- a/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java +++ b/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java @@ -107,6 +107,7 @@ protected LoadBalancer delegate() { return delegateLb; } + @Deprecated @Override public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { MetadataLoadBalancerConfig config @@ -114,6 +115,14 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { helper.setMetadata(config.metadataKey, config.metadataValue); delegateLb.handleResolvedAddresses(resolvedAddresses); } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + MetadataLoadBalancerConfig config + = (MetadataLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + helper.setMetadata(config.metadataKey, config.metadataValue); + return delegateLb.acceptResolvedAddresses(resolvedAddresses); + } } /** diff --git a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerProviderTest.java index 9f0b5f9578e..37ea24b2aa9 100644 --- a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerProviderTest.java @@ -16,6 +16,7 @@ package io.grpc.xds; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; import com.google.common.collect.ImmutableList; @@ -26,17 +27,13 @@ import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig; import java.util.List; import java.util.Map; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** Tests for {@link PriorityLoadBalancerProvider}. 
*/ @RunWith(JUnit4.class) public class PriorityLoadBalancerProviderTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @SuppressWarnings("ExpectedExceptionChecker") @Test @@ -48,8 +45,8 @@ public void priorityLbConfig_emptyPriorities() { newChildConfig(mock(LoadBalancerProvider.class), null), true)); List priorities = ImmutableList.of(); - thrown.expect(IllegalArgumentException.class); - new PriorityLbConfig(childConfigs, priorities); + assertThrows(IllegalArgumentException.class, + () -> new PriorityLbConfig(childConfigs, priorities)); } @SuppressWarnings("ExpectedExceptionChecker") @@ -62,8 +59,8 @@ public void priorityLbConfig_missingChildConfig() { newChildConfig(mock(LoadBalancerProvider.class), null), true)); List priorities = ImmutableList.of("p0", "p1"); - thrown.expect(IllegalArgumentException.class); - new PriorityLbConfig(childConfigs, priorities); + assertThrows(IllegalArgumentException.class, + () -> new PriorityLbConfig(childConfigs, priorities)); } private Object newChildConfig(LoadBalancerProvider provider, Object config) { diff --git a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java index fafcd4d674a..beb568be9ce 100644 --- a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java @@ -28,11 +28,13 @@ import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; 
import com.google.common.collect.ImmutableMap; @@ -69,6 +71,7 @@ import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.Captor; +import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -97,6 +100,8 @@ public void uncaughtException(Thread t, Throwable e) { public LoadBalancer newLoadBalancer(Helper helper) { fooHelpers.add(helper); LoadBalancer childBalancer = mock(LoadBalancer.class); + when(childBalancer.acceptResolvedAddresses(any(ResolvedAddresses.class))) + .thenReturn(Status.OK); fooBalancers.add(childBalancer); return childBalancer; } @@ -107,6 +112,8 @@ public LoadBalancer newLoadBalancer(Helper helper) { @Override public LoadBalancer newLoadBalancer(Helper helper) { LoadBalancer childBalancer = mock(LoadBalancer.class); + when(childBalancer.acceptResolvedAddresses(any(ResolvedAddresses.class))) + .thenReturn(Status.OK); barBalancers.add(childBalancer); return childBalancer; } @@ -141,7 +148,7 @@ public void tearDown() { } @Test - public void handleResolvedAddresses() { + public void acceptResolvedAddresses() { SocketAddress socketAddress = new InetSocketAddress(8080); EquivalentAddressGroup eag = new EquivalentAddressGroup(socketAddress); eag = AddressFilter.setPathFilter(eag, ImmutableList.of("p1")); @@ -162,16 +169,17 @@ public void handleResolvedAddresses() { ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1, "p2", priorityChildConfig2), ImmutableList.of("p0", "p1", "p2")); - priorityLb.handleResolvedAddresses( + Status status = priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(addresses) .setAttributes(attributes) .setLoadBalancingPolicyConfig(priorityLbConfig) .build()); + assertThat(status.getCode()).isEqualTo(Status.Code.OK); assertThat(fooBalancers).hasSize(1); assertThat(barBalancers).isEmpty(); LoadBalancer fooBalancer0 = Iterables.getOnlyElement(fooBalancers); - 
verify(fooBalancer0).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(fooBalancer0).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); ResolvedAddresses addressesReceived = resolvedAddressesCaptor.getValue(); assertThat(addressesReceived.getAddresses()).isEmpty(); assertThat(addressesReceived.getAttributes()).isEqualTo(attributes); @@ -182,7 +190,7 @@ public void handleResolvedAddresses() { assertThat(fooBalancers).hasSize(1); assertThat(barBalancers).hasSize(1); LoadBalancer barBalancer0 = Iterables.getOnlyElement(barBalancers); - verify(barBalancer0).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(barBalancer0).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); addressesReceived = resolvedAddressesCaptor.getValue(); assertThat(Iterables.getOnlyElement(addressesReceived.getAddresses()).getAddresses()) .containsExactly(socketAddress); @@ -194,7 +202,7 @@ public void handleResolvedAddresses() { assertThat(fooBalancers).hasSize(2); assertThat(barBalancers).hasSize(1); LoadBalancer fooBalancer1 = Iterables.getLast(fooBalancers); - verify(fooBalancer1).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(fooBalancer1).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); addressesReceived = resolvedAddressesCaptor.getValue(); assertThat(addressesReceived.getAddresses()).isEmpty(); assertThat(addressesReceived.getAttributes()).isEqualTo(attributes); @@ -211,14 +219,15 @@ public void handleResolvedAddresses() { ImmutableMap.of("p1", new PriorityChildConfig(newChildConfig(barLbProvider, newBarConfig), true)), ImmutableList.of("p1")); - priorityLb.handleResolvedAddresses( + status = priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(newAddresses) .setLoadBalancingPolicyConfig(newPriorityLbConfig) .build()); + assertThat(status.getCode()).isEqualTo(Status.Code.OK); assertThat(fooBalancers).hasSize(2); assertThat(barBalancers).hasSize(1); - verify(barBalancer0, 
times(2)).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(barBalancer0, times(2)).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); addressesReceived = resolvedAddressesCaptor.getValue(); assertThat(Iterables.getOnlyElement(addressesReceived.getAddresses()).getAddresses()) .containsExactly(newSocketAddress); @@ -232,6 +241,60 @@ public void handleResolvedAddresses() { verify(barBalancer0, never()).shutdown(); } + @Test + public void acceptResolvedAddresses_propagatesChildFailures() { + LoadBalancerProvider lbProvider = new CannedLoadBalancer.Provider(); + CannedLoadBalancer.Config internalTf = new CannedLoadBalancer.Config( + Status.INTERNAL, TRANSIENT_FAILURE); + CannedLoadBalancer.Config okTf = new CannedLoadBalancer.Config(Status.OK, TRANSIENT_FAILURE); + ResolvedAddresses resolvedAddresses = ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.EMPTY) + .build(); + + // tryNewPriority() propagates status + Status status = priorityLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new PriorityLbConfig( + ImmutableMap.of( + "p0", newPriorityChildConfig(lbProvider, internalTf, true)), + ImmutableList.of("p0"))) + .build()); + assertThat(status.getCode()).isNotEqualTo(Status.Code.OK); + + // Updating a child propagates status + status = priorityLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new PriorityLbConfig( + ImmutableMap.of( + "p0", newPriorityChildConfig(lbProvider, internalTf, true)), + ImmutableList.of("p0"))) + .build()); + assertThat(status.getCode()).isNotEqualTo(Status.Code.OK); + + // A single pre-existing child failure propagates + status = priorityLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new PriorityLbConfig( + ImmutableMap.of( + "p0", newPriorityChildConfig(lbProvider, okTf, true), + "p1", newPriorityChildConfig(lbProvider, okTf, true), + 
"p2", newPriorityChildConfig(lbProvider, okTf, true)), + ImmutableList.of("p0", "p1", "p2"))) + .build()); + assertThat(status.getCode()).isEqualTo(Status.Code.OK); + status = priorityLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new PriorityLbConfig( + ImmutableMap.of( + "p0", newPriorityChildConfig(lbProvider, okTf, true), + "p1", newPriorityChildConfig(lbProvider, internalTf, true), + "p2", newPriorityChildConfig(lbProvider, okTf, true)), + ImmutableList.of("p0", "p1", "p2"))) + .build()); + assertThat(status.getCode()).isNotEqualTo(Status.Code.OK); + } + @Test public void handleNameResolutionError() { Object fooConfig0 = new Object(); @@ -243,7 +306,7 @@ public void handleNameResolutionError() { PriorityLbConfig priorityLbConfig = new PriorityLbConfig(ImmutableMap.of("p0", priorityChildConfig0), ImmutableList.of("p0")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -255,7 +318,7 @@ public void handleNameResolutionError() { priorityLbConfig = new PriorityLbConfig(ImmutableMap.of("p1", priorityChildConfig1), ImmutableList.of("p1")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -286,7 +349,7 @@ public void typicalPriorityFailOverFlow() { ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1, "p2", priorityChildConfig2, "p3", priorityChildConfig3), ImmutableList.of("p0", "p1", "p2", "p3")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -315,6 +378,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { assertThat(fooBalancers).hasSize(2); 
assertThat(fooHelpers).hasSize(2); LoadBalancer balancer1 = Iterables.getLast(fooBalancers); + Helper helper1 = Iterables.getLast(fooHelpers); // p1 timeout, and fails over to p2 fakeClock.forwardTime(10, TimeUnit.SECONDS); @@ -362,14 +426,20 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { LoadBalancer balancer3 = Iterables.getLast(fooBalancers); Helper helper3 = Iterables.getLast(fooHelpers); - // p3 timeout then the channel should go to TRANSIENT_FAILURE + // p3 timeout then the channel should stay in CONNECTING fakeClock.forwardTime(10, TimeUnit.SECONDS); - assertCurrentPickerReturnsError(Status.Code.UNAVAILABLE, "timeout"); + assertCurrentPicker(CONNECTING, PickResult.withNoResult()); - // p3 fails then the picker should have error status updated + // p3 fails then the picker should still be waiting on p1 helper3.updateBalancingState( TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(Status.DATA_LOSS.withDescription("foo")))); + assertCurrentPicker(CONNECTING, PickResult.withNoResult()); + + // p1 fails then the picker should have error status updated to p3 + helper1.updateBalancingState( + TRANSIENT_FAILURE, + new FixedResultPicker(PickResult.withError(Status.DATA_LOSS.withDescription("bar")))); assertCurrentPickerReturnsError(Status.Code.DATA_LOSS, "foo"); // p2 gets back to READY @@ -419,7 +489,7 @@ public void idleToConnectingDoesNotTriggerFailOver() { new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), ImmutableList.of("p0", "p1")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -455,14 +525,13 @@ public void connectingResetFailOverIfSeenReadyOrIdleSinceTransientFailure() { new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), ImmutableList.of("p0", "p1")); - priorityLb.handleResolvedAddresses( + 
priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) .build()); // Nothing important about this verify, other than to provide a baseline - verify(helper, times(2)) - .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); + verify(helper).updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); assertThat(fooBalancers).hasSize(1); assertThat(fooHelpers).hasSize(1); Helper helper0 = Iterables.getOnlyElement(fooHelpers); @@ -478,7 +547,7 @@ public void connectingResetFailOverIfSeenReadyOrIdleSinceTransientFailure() { helper0.updateBalancingState( CONNECTING, EMPTY_PICKER); - verify(helper, times(3)) + verify(helper, times(2)) .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); // failover happens @@ -487,6 +556,55 @@ public void connectingResetFailOverIfSeenReadyOrIdleSinceTransientFailure() { assertThat(fooHelpers).hasSize(2); } + @Test + public void failoverTimerNotRestartedOnDupConnecting() { + InOrder inOrder = inOrder(helper); + PriorityChildConfig priorityChildConfig0 = + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); + PriorityChildConfig priorityChildConfig1 = + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); + PriorityLbConfig priorityLbConfig = + new PriorityLbConfig( + ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), + ImmutableList.of("p0", "p1")); + priorityLb.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .build()); + // Nothing important about this verify, other than to provide a baseline + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); + assertThat(fooBalancers).hasSize(1); + assertThat(fooHelpers).hasSize(1); + Helper helper0 = 
Iterables.getOnlyElement(fooHelpers); + + // Cause seenReadyOrIdleSinceTransientFailure = true + helper0.updateBalancingState(IDLE, EMPTY_PICKER); + inOrder.verify(helper) + .updateBalancingState(eq(IDLE), pickerReturns(PickResult.withNoResult())); + helper0.updateBalancingState(CONNECTING, EMPTY_PICKER); + + // p0 keeps repeating CONNECTING, failover happens + fakeClock.forwardTime(5, TimeUnit.SECONDS); + helper0.updateBalancingState(CONNECTING, EMPTY_PICKER); + fakeClock.forwardTime(5, TimeUnit.SECONDS); + assertThat(fooBalancers).hasSize(2); + assertThat(fooHelpers).hasSize(2); + inOrder.verify(helper, times(2)) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); + Helper helper1 = Iterables.getLast(fooHelpers); + + // p0 keeps repeating CONNECTING, no reset of failover timer + helper1.updateBalancingState(IDLE, EMPTY_PICKER); // Stop timer for p1 + inOrder.verify(helper) + .updateBalancingState(eq(IDLE), pickerReturns(PickResult.withNoResult())); + helper0.updateBalancingState(CONNECTING, EMPTY_PICKER); + fakeClock.forwardTime(10, TimeUnit.SECONDS); + inOrder.verify(helper, never()) + .updateBalancingState(eq(CONNECTING), any()); + } + @Test public void readyToConnectDoesNotFailOverButUpdatesPicker() { PriorityChildConfig priorityChildConfig0 = @@ -497,7 +615,7 @@ public void readyToConnectDoesNotFailOverButUpdatesPicker() { new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), ImmutableList.of("p0", "p1")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -530,7 +648,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { // resolution update without priority change does not trigger failover Attributes.Key fooKey = Attributes.Key.create("fooKey"); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( 
ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -559,7 +677,7 @@ public void typicalPriorityFailOverFlowWithIdleUpdate() { ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1, "p2", priorityChildConfig2, "p3", priorityChildConfig3), ImmutableList.of("p0", "p1", "p2", "p3")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -582,6 +700,7 @@ public void typicalPriorityFailOverFlowWithIdleUpdate() { assertThat(fooBalancers).hasSize(2); assertThat(fooHelpers).hasSize(2); LoadBalancer balancer1 = Iterables.getLast(fooBalancers); + Helper helper1 = Iterables.getLast(fooHelpers); // p1 timeout, and fails over to p2 fakeClock.forwardTime(10, TimeUnit.SECONDS); @@ -617,14 +736,20 @@ public void typicalPriorityFailOverFlowWithIdleUpdate() { LoadBalancer balancer3 = Iterables.getLast(fooBalancers); Helper helper3 = Iterables.getLast(fooHelpers); - // p3 timeout then the channel should go to TRANSIENT_FAILURE + // p3 timeout then the channel should stay in CONNECTING fakeClock.forwardTime(10, TimeUnit.SECONDS); - assertCurrentPickerReturnsError(Status.Code.UNAVAILABLE, "timeout"); + assertCurrentPicker(CONNECTING, PickResult.withNoResult()); - // p3 fails then the picker should have error status updated + // p3 fails then the picker should still be waiting on p1 helper3.updateBalancingState( TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(Status.DATA_LOSS.withDescription("foo")))); + assertCurrentPicker(CONNECTING, PickResult.withNoResult()); + + // p1 fails then the picker should have error status updated to p3 + helper1.updateBalancingState( + TRANSIENT_FAILURE, + new FixedResultPicker(PickResult.withError(Status.DATA_LOSS.withDescription("bar")))); assertCurrentPickerReturnsError(Status.Code.DATA_LOSS, "foo"); // p2 gets back to IDLE 
@@ -652,6 +777,55 @@ public void typicalPriorityFailOverFlowWithIdleUpdate() { verify(balancer3).shutdown(); } + @Test + public void failover_propagatesChildFailures() { + LoadBalancerProvider lbProvider = new CannedLoadBalancer.Provider(); + ResolvedAddresses resolvedAddresses = ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.EMPTY) + .build(); + + Status status = priorityLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new PriorityLbConfig( + ImmutableMap.of( + "p0", newPriorityChildConfig( + lbProvider, new CannedLoadBalancer.Config(Status.OK, TRANSIENT_FAILURE), true), + "p1", newPriorityChildConfig( + lbProvider, new CannedLoadBalancer.Config(Status.INTERNAL, CONNECTING), true)), + ImmutableList.of("p0", "p1"))) + .build()); + // Since P1's activation wasn't noticed by the result status, it triggered name resolution + assertThat(status.getCode()).isEqualTo(Status.Code.OK); + verify(helper).refreshNameResolution(); + } + + @Test + public void failoverTimer_propagatesChildFailures() { + LoadBalancerProvider lbProvider = new CannedLoadBalancer.Provider(); + ResolvedAddresses resolvedAddresses = ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.EMPTY) + .build(); + + Status status = priorityLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new PriorityLbConfig( + ImmutableMap.of( + "p0", newPriorityChildConfig( + lbProvider, new CannedLoadBalancer.Config(Status.OK, CONNECTING), true), + "p1", newPriorityChildConfig( + lbProvider, new CannedLoadBalancer.Config(Status.INTERNAL, CONNECTING), true)), + ImmutableList.of("p0", "p1"))) + .build()); + assertThat(status.getCode()).isEqualTo(Status.Code.OK); + + // P1's activation will refresh name resolution + verify(helper, never()).refreshNameResolution(); + fakeClock.forwardTime(10, TimeUnit.SECONDS); + verify(helper).refreshNameResolution(); + } 
+ @Test public void bypassReresolutionRequestsIfConfiged() { PriorityChildConfig priorityChildConfig0 = @@ -662,7 +836,7 @@ public void bypassReresolutionRequestsIfConfiged() { new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), ImmutableList.of("p0", "p1")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -690,12 +864,12 @@ public void raceBetweenShutdownAndChildLbBalancingStateUpdate() { new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), ImmutableList.of("p0", "p1")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) .build()); - verify(helper, times(2)).updateBalancingState(eq(CONNECTING), isA(SubchannelPicker.class)); + verify(helper).updateBalancingState(eq(CONNECTING), isA(SubchannelPicker.class)); // LB shutdown and subchannel state change can happen simultaneously. If shutdown runs first, // any further balancing state update should be ignored. 
@@ -717,7 +891,7 @@ public void noDuplicateOverallBalancingStateUpdate() { new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0), ImmutableList.of("p0")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -727,13 +901,13 @@ public void noDuplicateOverallBalancingStateUpdate() { new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), ImmutableList.of("p0", "p1")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) .build()); - verify(helper, times(6)).updateBalancingState(any(), any()); + verify(helper, times(4)).updateBalancingState(any(), any()); } private void assertLatestConnectivityState(ConnectivityState expectedState) { @@ -754,21 +928,28 @@ private void assertCurrentPickerReturnsError( } private void assertCurrentPickerPicksSubchannel(Subchannel expectedSubchannelToPick) { - assertLatestConnectivityState(READY); - PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); - assertThat(pickResult.getSubchannel()).isEqualTo(expectedSubchannelToPick); + assertCurrentPicker(READY, PickResult.withSubchannel(expectedSubchannelToPick)); } private void assertCurrentPickerIsBufferPicker() { - assertLatestConnectivityState(IDLE); + assertCurrentPicker(IDLE, PickResult.withNoResult()); + } + + private void assertCurrentPicker(ConnectivityState state, PickResult result) { + assertLatestConnectivityState(state); PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); - assertThat(pickResult).isEqualTo(PickResult.withNoResult()); + assertThat(pickResult).isEqualTo(result); } private Object newChildConfig(LoadBalancerProvider provider, Object config) { return 
GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(provider, config); } + private PriorityChildConfig newPriorityChildConfig( + LoadBalancerProvider provider, Object config, boolean ignoreRefresh) { + return new PriorityChildConfig(newChildConfig(provider, config), ignoreRefresh); + } + private static class FakeLoadBalancerProvider extends LoadBalancerProvider { @Override @@ -801,9 +982,10 @@ static class FakeLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { helper.updateBalancingState( TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(Status.INTERNAL))); + return Status.OK; } @Override @@ -814,4 +996,47 @@ public void handleNameResolutionError(Status error) { public void shutdown() { } } + + static final class CannedLoadBalancer extends LoadBalancer { + private final Helper helper; + + private CannedLoadBalancer(Helper helper) { + this.helper = helper; + } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses addresses) { + Config config = (Config) addresses.getLoadBalancingPolicyConfig(); + helper.updateBalancingState( + config.state, new FixedResultPicker(PickResult.withError(Status.INTERNAL))); + return config.resolvedAddressesResult; + } + + @Override + public void handleNameResolutionError(Status status) {} + + @Override + public void shutdown() {} + + static final class Provider extends StandardLoadBalancerProvider { + public Provider() { + super("echo"); + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return new CannedLoadBalancer(helper); + } + } + + static final class Config { + final Status resolvedAddressesResult; + final ConnectivityState state; + + public Config(Status resolvedAddressesResult, ConnectivityState state) { + this.resolvedAddressesResult = resolvedAddressesResult; + this.state = state; + } + } + } } diff --git 
a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java index 29af01b222f..334e159dd1d 100644 --- a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java @@ -78,6 +78,15 @@ public class RbacFilterTest { private static final String PATH = "auth"; private static final StringMatcher STRING_MATCHER = StringMatcher.newBuilder().setExact("/" + PATH).setIgnoreCase(true).build(); + private static final RbacFilter.Provider FILTER_PROVIDER = new RbacFilter.Provider(); + + private final String name = "theFilterName"; + + @Test + public void filterType_serverOnly() { + assertThat(FILTER_PROVIDER.isClientFilter()).isFalse(); + assertThat(FILTER_PROVIDER.isServerFilter()).isTrue(); + } @Test @SuppressWarnings({"unchecked", "deprecation"}) @@ -219,14 +228,15 @@ public void headerParser_headerName() { @SuppressWarnings("unchecked") public void compositeRules() { MetadataMatcher metadataMatcher = MetadataMatcher.newBuilder().build(); + @SuppressWarnings("deprecation") + Permission permissionMetadata = Permission.newBuilder().setMetadata(metadataMatcher).build(); List permissionList = Arrays.asList( Permission.newBuilder().setOrRules(Permission.Set.newBuilder().addRules( - Permission.newBuilder().setMetadata(metadataMatcher).build() - ).build()).build()); + permissionMetadata).build()).build()); + @SuppressWarnings("deprecation") + Principal principalMetadata = Principal.newBuilder().setMetadata(metadataMatcher).build(); List principalList = Arrays.asList( - Principal.newBuilder().setNotId( - Principal.newBuilder().setMetadata(metadataMatcher).build() - ).build()); + Principal.newBuilder().setNotId(principalMetadata).build()); ConfigOrError result = parse(permissionList, principalList); assertThat(result.errorDetail).isNull(); assertThat(result.config).isInstanceOf(RbacConfig.class); @@ -251,7 +261,7 @@ public void testAuthorizationInterceptor() { 
OrMatcher.create(AlwaysTrueMatcher.INSTANCE)); AuthConfig authconfig = AuthConfig.create(Collections.singletonList(policyMatcher), GrpcAuthorizationEngine.Action.ALLOW); - new RbacFilter().buildServerInterceptor(RbacConfig.create(authconfig), null) + FILTER_PROVIDER.newInstance(name).buildServerInterceptor(RbacConfig.create(authconfig), null) .interceptCall(mockServerCall, new Metadata(), mockHandler); verify(mockHandler, never()).startCall(eq(mockServerCall), any(Metadata.class)); ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); @@ -263,7 +273,7 @@ public void testAuthorizationInterceptor() { authconfig = AuthConfig.create(Collections.singletonList(policyMatcher), GrpcAuthorizationEngine.Action.DENY); - new RbacFilter().buildServerInterceptor(RbacConfig.create(authconfig), null) + FILTER_PROVIDER.newInstance(name).buildServerInterceptor(RbacConfig.create(authconfig), null) .interceptCall(mockServerCall, new Metadata(), mockHandler); verify(mockHandler).startCall(eq(mockServerCall), any(Metadata.class)); } @@ -289,7 +299,7 @@ public void handleException() { .putPolicies("policy-name", Policy.newBuilder().setCondition(Expr.newBuilder().build()).build()) .build()).build(); - result = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + result = FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto)); assertThat(result.errorDetail).isNotNull(); } @@ -311,10 +321,10 @@ public void overrideConfig() { RbacConfig original = RbacConfig.create(authconfig); RBACPerRoute rbacPerRoute = RBACPerRoute.newBuilder().build(); - RbacConfig override = - new RbacFilter().parseFilterConfigOverride(Any.pack(rbacPerRoute)).config; + RbacConfig override = FILTER_PROVIDER.parseFilterConfigOverride(Any.pack(rbacPerRoute)).config; assertThat(override).isEqualTo(RbacConfig.create(null)); - ServerInterceptor interceptor = new RbacFilter().buildServerInterceptor(original, override); + ServerInterceptor interceptor = + 
FILTER_PROVIDER.newInstance(name).buildServerInterceptor(original, override); assertThat(interceptor).isNull(); policyMatcher = PolicyMatcher.create("policy-matcher-override", @@ -324,7 +334,7 @@ public void overrideConfig() { GrpcAuthorizationEngine.Action.ALLOW); override = RbacConfig.create(authconfig); - new RbacFilter().buildServerInterceptor(original, override) + FILTER_PROVIDER.newInstance(name).buildServerInterceptor(original, override) .interceptCall(mockServerCall, new Metadata(), mockHandler); verify(mockHandler).startCall(eq(mockServerCall), any(Metadata.class)); verify(mockServerCall).getAttributes(); @@ -336,22 +346,22 @@ public void ignoredConfig() { Message rawProto = io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC.newBuilder() .setRules(RBAC.newBuilder().setAction(Action.LOG) .putPolicies("policy-name", Policy.newBuilder().build()).build()).build(); - ConfigOrError result = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + ConfigOrError result = FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto)); assertThat(result.config).isEqualTo(RbacConfig.create(null)); } @Test public void testOrderIndependenceOfPolicies() { Message rawProto = buildComplexRbac(ImmutableList.of(1, 2, 3, 4, 5, 6), true); - ConfigOrError ascFirst = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + ConfigOrError ascFirst = FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto)); rawProto = buildComplexRbac(ImmutableList.of(1, 2, 3, 4, 5, 6), false); - ConfigOrError ascLast = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + ConfigOrError ascLast = FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto)); assertThat(ascFirst.config).isEqualTo(ascLast.config); rawProto = buildComplexRbac(ImmutableList.of(6, 5, 4, 3, 2, 1), true); - ConfigOrError decFirst = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + ConfigOrError decFirst = FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto)); assertThat(ascFirst.config).isEqualTo(decFirst.config); } @@ 
-373,14 +383,14 @@ private MethodDescriptor.Builder method() { private ConfigOrError parse(List permissionList, List principalList) { - return RbacFilter.parseRbacConfig(buildRbac(permissionList, principalList)); + return RbacFilter.Provider.parseRbacConfig(buildRbac(permissionList, principalList)); } private ConfigOrError parseRaw(List permissionList, List principalList) { Message rawProto = buildRbac(permissionList, principalList); Any proto = Any.pack(rawProto); - return new RbacFilter().parseFilterConfig(proto); + return FILTER_PROVIDER.parseFilterConfig(proto); } private io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC buildRbac( @@ -448,6 +458,6 @@ private ConfigOrError parseOverride(List permissionList, RBACPerRoute rbacPerRoute = RBACPerRoute.newBuilder().setRbac( buildRbac(permissionList, principalList)).build(); Any proto = Any.pack(rbacPerRoute); - return new RbacFilter().parseFilterConfigOverride(proto); + return FILTER_PROVIDER.parseFilterConfigOverride(proto); } } diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java index 87615a125c0..66c9c5c537e 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java @@ -42,6 +42,8 @@ @RunWith(JUnit4.class) public class RingHashLoadBalancerProviderTest { private static final String AUTHORITY = "foo.googleapis.com"; + private static final String GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY = + "GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY"; private final SynchronizationContext syncContext = new SynchronizationContext( new UncaughtExceptionHandler() { @@ -81,6 +83,7 @@ public void parseLoadBalancingConfig_valid() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(10L); assertThat(config.maxRingSize).isEqualTo(100L); + 
assertThat(config.requestHashHeader).isEmpty(); } @Test @@ -92,6 +95,7 @@ public void parseLoadBalancingConfig_missingRingSize_useDefaults() throws IOExce RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(RingHashLoadBalancerProvider.DEFAULT_MIN_RING_SIZE); assertThat(config.maxRingSize).isEqualTo(RingHashLoadBalancerProvider.DEFAULT_MAX_RING_SIZE); + assertThat(config.requestHashHeader).isEmpty(); } @Test @@ -102,7 +106,7 @@ public void parseLoadBalancingConfig_invalid_negativeSize() throws IOException { assertThat(configOrError.getError()).isNotNull(); assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) - .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); + .isEqualTo("Invalid 'minRingSize'/'maxRingSize'"); } @Test @@ -113,7 +117,7 @@ public void parseLoadBalancingConfig_invalid_minGreaterThanMax() throws IOExcept assertThat(configOrError.getError()).isNotNull(); assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) - .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); + .isEqualTo("Invalid 'minRingSize'/'maxRingSize'"); } @Test @@ -127,6 +131,7 @@ public void parseLoadBalancingConfig_ringTooLargeUsesCap() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(10); assertThat(config.maxRingSize).isEqualTo(RingHashOptions.DEFAULT_RING_SIZE_CAP); + assertThat(config.requestHashHeader).isEmpty(); } @Test @@ -142,6 +147,7 @@ public void parseLoadBalancingConfig_ringCapCanBeRaised() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); assertThat(config.maxRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); + assertThat(config.requestHashHeader).isEmpty(); // Reset to 
avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -159,6 +165,7 @@ public void parseLoadBalancingConfig_ringCapIsClampedTo8M() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); assertThat(config.maxRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -176,6 +183,7 @@ public void parseLoadBalancingConfig_ringCapCanBeLowered() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(1); assertThat(config.maxRingSize).isEqualTo(1); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -193,6 +201,7 @@ public void parseLoadBalancingConfig_ringCapLowerLimitIs1() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(1); assertThat(config.maxRingSize).isEqualTo(1); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -205,7 +214,7 @@ public void parseLoadBalancingConfig_zeroMinRingSize() throws IOException { assertThat(configOrError.getError()).isNotNull(); assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) - .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); + .isEqualTo("Invalid 'minRingSize'/'maxRingSize'"); } @Test @@ -216,7 +225,60 @@ public void parseLoadBalancingConfig_minRingSizeGreaterThanMaxRingSize() throws 
assertThat(configOrError.getError()).isNotNull(); assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) - .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); + .isEqualTo("Invalid 'minRingSize'/'maxRingSize'"); + } + + @Test + public void parseLoadBalancingConfig_requestHashHeaderIgnoredWhenEnvVarNotSet() + throws IOException { + String lbConfig = + "{\"minRingSize\" : 10, \"maxRingSize\" : 100, \"requestHashHeader\" : \"dummy-hash\"}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10L); + assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEmpty(); + } + + @Test + public void parseLoadBalancingConfig_requestHashHeaderSetWhenEnvVarSet() throws IOException { + System.setProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY, "true"); + try { + String lbConfig = + "{\"minRingSize\" : 10, \"maxRingSize\" : 100, \"requestHashHeader\" : \"dummy-hash\"}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10L); + assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEqualTo("dummy-hash"); + assertThat(config.toString()).contains("minRingSize=10"); + assertThat(config.toString()).contains("maxRingSize=100"); + assertThat(config.toString()).contains("requestHashHeader=dummy-hash"); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY); + } + } + + @Test + public void parseLoadBalancingConfig_requestHashHeaderUnsetWhenEnvVarSet_useDefaults() + throws IOException 
{ + System.setProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY, "true"); + try { + String lbConfig = "{\"minRingSize\" : 10, \"maxRingSize\" : 100}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10L); + assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEmpty(); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY); + } } @SuppressWarnings("unchecked") diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java index 047ba71bbe0..b515ed81158 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java @@ -23,7 +23,6 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.util.MultiChildLoadBalancer.IS_PETIOLE_POLICY; import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.DO_NOT_RESET_HELPER; import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.DO_NOT_VERIFY; import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.RESET_SUBCHANNEL_MOCKS; @@ -43,11 +42,13 @@ import com.google.common.collect.Iterables; import com.google.common.primitives.UnsignedInteger; +import com.google.common.testing.EqualsTester; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickDetailsConsumer; @@ -62,9 +63,11 @@ 
import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; import io.grpc.internal.PickFirstLoadBalancerProvider; +import io.grpc.internal.PickFirstLoadBalancerProviderAccessor; import io.grpc.internal.PickSubchannelArgsImpl; import io.grpc.testing.TestMethodDescriptors; import io.grpc.util.AbstractTestHelper; +import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import java.lang.Thread.UncaughtExceptionHandler; @@ -74,8 +77,11 @@ import java.util.Collections; import java.util.Deque; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Random; +import java.util.Set; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -93,6 +99,9 @@ @RunWith(JUnit4.class) public class RingHashLoadBalancerTest { private static final String AUTHORITY = "foo.googleapis.com"; + private static final String CUSTOM_REQUEST_HASH_HEADER = "custom-request-hash-header"; + private static final Metadata.Key CUSTOM_METADATA_KEY = + Metadata.Key.of(CUSTOM_REQUEST_HASH_HEADER, Metadata.ASCII_STRING_MARSHALLER); private static final Attributes.Key CUSTOM_KEY = Attributes.Key.create("custom-key"); private static final ConnectivityStateInfo CSI_CONNECTING = ConnectivityStateInfo.forNonError(CONNECTING); @@ -115,6 +124,7 @@ public void uncaughtException(Thread t, Throwable e) { @Captor private ArgumentCaptor pickerCaptor; private RingHashLoadBalancer loadBalancer; + private boolean defaultNewPickFirst = PickFirstLoadBalancerProvider.isEnabledNewPickFirst(); @Before public void setUp() { @@ -126,6 +136,7 @@ public void setUp() { @After public void tearDown() { + PickFirstLoadBalancerProviderAccessor.setEnableNewPickFirst(defaultNewPickFirst); loadBalancer.shutdown(); for (Subchannel subchannel : subchannels.values()) { verify(subchannel).shutdown(); @@ -135,7 +146,7 @@ public void tearDown() { 
@Test public void subchannelLazyConnectUntilPicked() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1); // one server Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() @@ -150,7 +161,8 @@ public void subchannelLazyConnectUntilPicked() { assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getSubchannel()).isNull(); Subchannel subchannel = Iterables.getOnlyElement(subchannels.values()); - int expectedTimes = PickFirstLoadBalancerProvider.isEnabledHappyEyeballs() ? 1 : 2; + int expectedTimes = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() + && !PickFirstLoadBalancerProvider.isEnabledHappyEyeballs() ? 1 : 2; verify(subchannel, times(expectedTimes)).requestConnection(); verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); @@ -168,7 +180,7 @@ public void subchannelLazyConnectUntilPicked() { @Test public void subchannelNotAutoReconnectAfterReenteringIdle() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1); // one server Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() @@ -176,15 +188,15 @@ public void subchannelNotAutoReconnectAfterReenteringIdle() { assertThat(addressesAcceptanceStatus.isOk()).isTrue(); verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); - ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); - assertThat(subchannels.get(Collections.singletonList(childLbState.getEag()))).isNull(); + assertThat(subchannels).isEmpty(); // Picking subchannel triggers connection. 
PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); pickerCaptor.getValue().pickSubchannel(args); - Subchannel subchannel = subchannels.get(Collections.singletonList(childLbState.getEag())); + Subchannel subchannel = subchannels.get(Collections.singletonList(servers.get(0))); InOrder inOrder = Mockito.inOrder(helper, subchannel); - int expectedTimes = PickFirstLoadBalancerProvider.isEnabledHappyEyeballs() ? 1 : 2; + int expectedTimes = PickFirstLoadBalancerProvider.isEnabledHappyEyeballs() + || !PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 2 : 1; inOrder.verify(subchannel, times(expectedTimes)).requestConnection(); deliverSubchannelState(subchannel, CSI_READY); inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); @@ -199,7 +211,7 @@ public void subchannelNotAutoReconnectAfterReenteringIdle() { @Test public void aggregateSubchannelStates_connectingReadyIdleFailure() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1); InOrder inOrder = Mockito.inOrder(helper); @@ -243,13 +255,13 @@ public void aggregateSubchannelStates_connectingReadyIdleFailure() { inOrder.verify(helper).refreshNameResolution(); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any()); } - verifyConnection(0); + verifyConnection(1); } private void verifyConnection(int times) { for (int i = 0; i < times; i++) { Subchannel connectOnce = connectionRequestedQueue.poll(); - assertWithMessage("Null connection is at (%s) of (%s)", i, times) + assertWithMessage("Expected %s new connections, but found %s", times, i) .that(connectOnce).isNotNull(); clearInvocations(connectOnce); } @@ -258,7 +270,7 @@ private void verifyConnection(int times) { @Test public void aggregateSubchannelStates_allSubchannelsInTransientFailure() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new 
RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1, 1, 1); List subChannelList = initializeLbSubchannels(config, servers, STAY_IN_CONNECTING); @@ -316,7 +328,7 @@ private void refreshInvokedAndUpdateBS(InOrder inOrder, ConnectivityState state) @Test public void ignoreShutdownSubchannelStateChange() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -332,7 +344,7 @@ public void ignoreShutdownSubchannelStateChange() { @Test public void deterministicPickWithHostsPartiallyRemoved() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1, 1, 1, 1); initializeLbSubchannels(config, servers); InOrder inOrder = Mockito.inOrder(helper); @@ -372,7 +384,7 @@ public void deterministicPickWithHostsPartiallyRemoved() { @Test public void deterministicPickWithNewHostsAdded() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1); // server0 and server1 initializeLbSubchannels(config, servers, DO_NOT_VERIFY, DO_NOT_RESET_HELPER); @@ -404,6 +416,139 @@ public void deterministicPickWithNewHostsAdded() { inOrder.verifyNoMoreInteractions(); } + @Test + public void deterministicPickWithRequestHashHeader_oneHeaderValue() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, CUSTOM_REQUEST_HASH_HEADER); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + InOrder inOrder = Mockito.inOrder(helper); + + // Bring all subchannels to READY. 
+ for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, CSI_READY); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + } + + // Pick subchannel with custom request hash header where the rpc hash hits server1. + Metadata headers = new Metadata(); + headers.put(CUSTOM_METADATA_KEY, "FakeSocketAddress-server1_0"); + PickSubchannelArgs args = + new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), + headers, + CallOptions.DEFAULT, + new PickDetailsConsumer() {}); + SubchannelPicker picker = pickerCaptor.getValue(); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); + } + + @Test + public void deterministicPickWithRequestHashHeader_multipleHeaderValues() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, CUSTOM_REQUEST_HASH_HEADER); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + InOrder inOrder = Mockito.inOrder(helper); + + // Bring all subchannels to READY. + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, CSI_READY); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + } + + // Pick subchannel with custom request hash header with multiple values for the same key where + // the rpc hash hits server1. 
+ Metadata headers = new Metadata(); + headers.put(CUSTOM_METADATA_KEY, "FakeSocketAddress-server0_0"); + headers.put(CUSTOM_METADATA_KEY, "FakeSocketAddress-server1_0"); + PickSubchannelArgs args = + new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), + headers, + CallOptions.DEFAULT, + new PickDetailsConsumer() {}); + SubchannelPicker picker = pickerCaptor.getValue(); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); + } + + @Test + public void pickWithRandomHash_allSubchannelsReady() { + loadBalancer = new RingHashLoadBalancer(helper, new FakeRandom()); + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(2, 2, "dummy-random-hash"); + List servers = createWeightedServerAddrs(1, 1); + initializeLbSubchannels(config, servers); + InOrder inOrder = Mockito.inOrder(helper); + + // Bring all subchannels to READY. + Map pickCounts = new HashMap<>(); + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, CSI_READY); + pickCounts.put(subchannel.getAddresses(), 0); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + } + + // Pick subchannel 100 times with random hash. + SubchannelPicker picker = pickerCaptor.getValue(); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + for (int i = 0; i < 100; ++i) { + Subchannel pickedSubchannel = picker.pickSubchannel(args).getSubchannel(); + EquivalentAddressGroup addr = pickedSubchannel.getAddresses(); + pickCounts.put(addr, pickCounts.get(addr) + 1); + } + + // Verify the distribution is uniform where server0 and server1 are exactly picked 50 times. 
+ assertThat(pickCounts.get(servers.get(0))).isEqualTo(50); + assertThat(pickCounts.get(servers.get(1))).isEqualTo(50); + } + + @Test + public void pickWithRandomHash_atLeastOneSubchannelConnecting() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, "dummy-random-hash"); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + + // Bring one subchannel to CONNECTING. + deliverSubchannelState(getSubChannel(servers.get(0)), CSI_CONNECTING); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + + // Pick subchannel with random hash does not trigger connection. + SubchannelPicker picker = pickerCaptor.getValue(); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel()).isNull(); // buffer request + verifyConnection(0); + } + + @Test + public void pickWithRandomHash_firstSubchannelInTransientFailure_remainingSubchannelsIdle() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, "dummy-random-hash"); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + + // Bring one subchannel to TRANSIENT_FAILURE. + deliverSubchannelUnreachable(getSubChannel(servers.get(0))); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verifyConnection(1); + + // Pick subchannel with random hash does trigger connection by walking the ring + // and choosing the first (at most one) IDLE subchannel along the way. 
+ SubchannelPicker picker = pickerCaptor.getValue(); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel()).isNull(); // buffer request + verifyConnection(1); + } + private Subchannel getSubChannel(EquivalentAddressGroup eag) { return subchannels.get(Collections.singletonList(eag)); } @@ -411,7 +556,7 @@ private Subchannel getSubChannel(EquivalentAddressGroup eag) { @Test public void skipFailingHosts_pickNextNonFailingHost() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( @@ -420,7 +565,7 @@ public void skipFailingHosts_pickNextNonFailingHost() { assertThat(addressesAcceptanceStatus.isOk()).isTrue(); // Create subchannel for the first address - loadBalancer.getChildLbStateEag(servers.get(0)).getCurrentPicker() + loadBalancer.getChildLbStates().iterator().next().getCurrentPicker() .pickSubchannel(getDefaultPickSubchannelArgs(hashFunc.hashVoid())); verifyConnection(1); @@ -438,13 +583,15 @@ public void skipFailingHosts_pickNextNonFailingHost() { getSubChannel(servers.get(0)), ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("unreachable"))); - verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verify(helper, atLeastOnce()).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getSubchannel()).isNull(); // buffer request - int expectedTimes = PickFirstLoadBalancerProvider.isEnabledHappyEyeballs() ? 
1 : 2; // verify kicked off connection to server2 + int expectedTimes = PickFirstLoadBalancerProvider.isEnabledHappyEyeballs() + || !PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 2 : 1; + verify(getSubChannel(servers.get(1)), times(expectedTimes)).requestConnection(); assertThat(subchannels.size()).isEqualTo(2); // no excessive connection @@ -479,7 +626,7 @@ private PickSubchannelArgs getDefaultPickSubchannelArgsForServer(int serverid) { @Test public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -501,8 +648,8 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { getSubchannel(servers, 2), ConnectivityStateInfo.forTransientFailure( Status.PERMISSION_DENIED.withDescription("permission denied"))); - verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - verifyConnection(0); + verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + verifyConnection(2); PickResult result = pickerCaptor.getValue().pickSubchannel(args); // activate last subchannel assertThat(result.getStatus().isOk()).isTrue(); int expectedCount = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 0 : 1; @@ -545,7 +692,7 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { @Test public void removingAddressShutdownSubchannel() { // Map each server address to exactly one ring entry. 
- RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List svs1 = createWeightedServerAddrs(1, 1, 1); List subchannels1 = initializeLbSubchannels(config, svs1, STAY_IN_CONNECTING); @@ -562,7 +709,7 @@ public void removingAddressShutdownSubchannel() { @Test public void allSubchannelsInTransientFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -574,7 +721,7 @@ public void allSubchannelsInTransientFailure() { } verify(helper, atLeastOnce()) .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - verifyConnection(0); + verifyConnection(2); // Picking subchannel triggers connection. RPC hash hits server0. PickSubchannelArgs args = getDefaultPickSubchannelArgsForServer(0); @@ -589,16 +736,17 @@ public void allSubchannelsInTransientFailure() { @Test public void firstSubchannelIdle() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); - // Go to TF does nothing, though PF will try to reconnect after backoff + // As per gRFC A61, entering TF triggers a proactive connection attempt + // on an IDLE subchannel because no other subchannel is currently CONNECTING. deliverSubchannelState(getSubchannel(servers, 1), ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("unreachable"))); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - verifyConnection(0); + verifyConnection(1); // Picking subchannel triggers connection. RPC hash hits server0. 
PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); @@ -610,7 +758,7 @@ public void firstSubchannelIdle() { @Test public void firstSubchannelConnecting() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -634,7 +782,7 @@ private Subchannel getSubchannel(List servers, int serve @Test public void firstSubchannelFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); List subchannelList = @@ -649,7 +797,7 @@ public void firstSubchannelFailure() { ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("unreachable"))); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - verifyConnection(0); + verifyConnection(1); // Per GRFC A61 Picking subchannel should no longer request connections that were failing PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); @@ -665,7 +813,7 @@ public void firstSubchannelFailure() { @Test public void secondSubchannelConnecting() { // Map each server address to exactly one ring entry. 
- RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -677,7 +825,7 @@ public void secondSubchannelConnecting() { Subchannel firstSubchannel = getSubchannel(servers, 0); deliverSubchannelUnreachable(firstSubchannel); - verifyConnection(0); + verifyConnection(1); deliverSubchannelState(getSubchannel(servers, 2), CSI_CONNECTING); verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -686,7 +834,7 @@ public void secondSubchannelConnecting() { // Picking subchannel when idle triggers connection. deliverSubchannelState(getSubchannel(servers, 2), ConnectivityStateInfo.forNonError(IDLE)); - verifyConnection(0); + verifyConnection(1); PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); @@ -696,7 +844,7 @@ public void secondSubchannelConnecting() { @Test public void secondSubchannelFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -710,7 +858,7 @@ public void secondSubchannelFailure() { deliverSubchannelUnreachable(firstSubchannel); deliverSubchannelUnreachable(getSubchannel(servers, 2)); verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - verifyConnection(0); + verifyConnection(2); // Picking subchannel triggers connection. PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); @@ -723,7 +871,7 @@ public void secondSubchannelFailure() { @Test public void thirdSubchannelConnecting() { // Map each server address to exactly one ring entry. 
- RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -740,7 +888,7 @@ public void thirdSubchannelConnecting() { deliverSubchannelState(getSubchannel(servers, 1), CSI_CONNECTING); verify(helper, atLeastOnce()) .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - verifyConnection(0); + verifyConnection(2); // Picking subchannel should not trigger connection per gRFC A61. PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); @@ -752,7 +900,7 @@ public void thirdSubchannelConnecting() { @Test public void stickyTransientFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -762,7 +910,7 @@ public void stickyTransientFailure() { deliverSubchannelUnreachable(firstSubchannel); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - verifyConnection(0); + verifyConnection(1); reset(helper); deliverSubchannelState(firstSubchannel, ConnectivityStateInfo.forNonError(IDLE)); @@ -781,7 +929,7 @@ public void stickyTransientFailure() { @Test public void largeWeights() { - RingHashConfig config = new RingHashConfig(10000, 100000); // large ring + RingHashConfig config = new RingHashConfig(10000, 100000, ""); // large ring List servers = createWeightedServerAddrs(Integer.MAX_VALUE, 10, 100); // MAX:10:100 @@ -819,7 +967,7 @@ public void largeWeights() { @Test public void hostSelectionProportionalToWeights() { - RingHashConfig config = new RingHashConfig(10000, 100000); // large ring + RingHashConfig config = new RingHashConfig(10000, 100000, ""); // large ring List servers = createWeightedServerAddrs(1, 10, 100); // 1:10:100 initializeLbSubchannels(config, 
servers); @@ -862,7 +1010,7 @@ public void nameResolutionErrorWithNoActiveSubchannels() { @Test public void nameResolutionErrorWithActiveSubchannels() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1); initializeLbSubchannels(config, servers, DO_NOT_VERIFY, DO_NOT_RESET_HELPER); @@ -884,7 +1032,7 @@ public void nameResolutionErrorWithActiveSubchannels() { @Test public void duplicateAddresses() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createRepeatedServerAddrs(1, 2, 3); initializeLbSubchannels(config, servers, DO_NOT_VERIFY); @@ -903,6 +1051,116 @@ public void duplicateAddresses() { assertThat(description).contains("Address: FakeSocketAddress-server2, count: 3"); } + @Test + public void subchannelHealthObserved() throws Exception { + // Only the new PF policy observes the new separate listener for health + PickFirstLoadBalancerProviderAccessor.setEnableNewPickFirst(true); + // PickFirst does most of this work. 
If the test fails, check IS_PETIOLE_POLICY + Map healthListeners = new HashMap<>(); + loadBalancer = new RingHashLoadBalancer(new ForwardingLoadBalancerHelper() { + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + Subchannel subchannel = super.createSubchannel(args.toBuilder() + .setAttributes(args.getAttributes().toBuilder() + .set(LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY, true) + .build()) + .build()); + healthListeners.put( + subchannel, args.getOption(LoadBalancer.HEALTH_CONSUMER_LISTENER_ARG_KEY)); + return subchannel; + } + + @Override + protected Helper delegate() { + return helper; + } + }); + + InOrder inOrder = Mockito.inOrder(helper); + List servers = createWeightedServerAddrs(1, 1); + initializeLbSubchannels(new RingHashConfig(10, 100, ""), servers); + Subchannel subchannel0 = subchannels.get(Collections.singletonList(servers.get(0))); + Subchannel subchannel1 = subchannels.get(Collections.singletonList(servers.get(1))); + + // Subchannels go READY, but the LB waits for health + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + } + inOrder.verify(helper, times(0)).updateBalancingState(eq(READY), any(SubchannelPicker.class)); + + // Health results lets subchannels go READY + healthListeners.get(subchannel0).onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + healthListeners.get(subchannel1).onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(helper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); + SubchannelPicker picker = pickerCaptor.getValue(); + Random random = new Random(1); + Set picks = new HashSet<>(); + for (int i = 0; i < 10; i++) { + picks.add( + picker.pickSubchannel(getDefaultPickSubchannelArgs(random.nextLong())).getSubchannel()); + } + assertThat(picks).containsExactly(subchannel0, subchannel1); + + // Unhealthy subchannel skipped + 
healthListeners.get(subchannel0).onSubchannelState( + ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription("oh no"))); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + picker = pickerCaptor.getValue(); + random.setSeed(1); + picks.clear(); + for (int i = 0; i < 10; i++) { + picks.add( + picker.pickSubchannel(getDefaultPickSubchannelArgs(random.nextLong())).getSubchannel()); + } + assertThat(picks).containsExactly(subchannel1); + } + + @Test + public void config_equalsTester() { + new EqualsTester() + .addEqualityGroup( + new RingHashConfig(1, 2, "headerA"), + new RingHashConfig(1, 2, "headerA")) + .addEqualityGroup(new RingHashConfig(1, 1, "headerA")) + .addEqualityGroup(new RingHashConfig(2, 2, "headerA")) + .addEqualityGroup(new RingHashConfig(1, 2, "headerB")) + .addEqualityGroup(new RingHashConfig(1, 2, "")) + .testEquals(); + } + + @Test + public void tfWithoutConnectingChild_triggersIdleChildConnection() { + RingHashConfig config = new RingHashConfig(10, 100, ""); + List servers = createWeightedServerAddrs(1, 1); + + initializeLbSubchannels(config, servers); + + Subchannel tfSubchannel = getSubchannel(servers, 0); + Subchannel idleSubchannel = getSubchannel(servers, 1); + + deliverSubchannelUnreachable(tfSubchannel); + + Subchannel requested = connectionRequestedQueue.poll(); + assertThat(requested).isSameInstanceAs(idleSubchannel); + assertThat(connectionRequestedQueue.poll()).isNull(); + } + + @Test + public void tfWithReadyChild_doesNotTriggerIdleChildConnection() { + RingHashConfig config = new RingHashConfig(10, 100, ""); + List servers = createWeightedServerAddrs(1, 1, 1); + + initializeLbSubchannels(config, servers); + + Subchannel tfSubchannel = getSubchannel(servers, 0); + Subchannel readySubchannel = getSubchannel(servers, 1); + + deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelUnreachable(tfSubchannel); + + 
assertThat(connectionRequestedQueue.poll()).isNull(); + } + private List initializeLbSubchannels(RingHashConfig config, List servers, InitializationFlags... initFlags) { @@ -947,8 +1205,6 @@ private List initializeLbSubchannels(RingHashConfig config, for (ChildLbState childLbState : loadBalancer.getChildLbStates()) { childLbState.getCurrentPicker() .pickSubchannel(getDefaultPickSubchannelArgs(hashFunc.hashVoid())); - assertThat(childLbState.getResolvedAddresses().getAttributes().get(IS_PETIOLE_POLICY)) - .isTrue(); } if (doVerifies) { @@ -1012,7 +1268,7 @@ private static List createWeightedServerAddrs(long... we for (int i = 0; i < weights.length; i++) { SocketAddress addr = new FakeSocketAddress("server" + i); Attributes attr = Attributes.newBuilder().set( - InternalXdsAttributes.ATTR_SERVER_WEIGHT, weights[i]).build(); + XdsAttributes.ATTR_SERVER_WEIGHT, weights[i]).build(); EquivalentAddressGroup eag = new EquivalentAddressGroup(addr, attr); addrs.add(eag); } @@ -1095,6 +1351,30 @@ public void requestConnection() { } } + private static final class FakeRandom implements ThreadSafeRandom { + int counter = 0; + + @Override + public int nextInt(int bound) { + throw new UnsupportedOperationException("Should not be called"); + } + + @Override + public long nextLong() { + ++counter; + if (counter % 2 == 0) { + return XxHash64.INSTANCE.hashAsciiString("FakeSocketAddress-server0_0"); + } else { + return XxHash64.INSTANCE.hashAsciiString("FakeSocketAddress-server1_0"); + } + } + + @Override + public long nextLong(long bound) { + throw new UnsupportedOperationException("Should not be called"); + } + } + enum InitializationFlags { DO_NOT_VERIFY, RESET_SUBCHANNEL_MOCKS, diff --git a/xds/src/test/java/io/grpc/xds/RouterFilterTest.java b/xds/src/test/java/io/grpc/xds/RouterFilterTest.java new file mode 100644 index 00000000000..30fd8a6dc38 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/RouterFilterTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2025 The gRPC Authors + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link RouterFilter}. */ +@RunWith(JUnit4.class) +public class RouterFilterTest { + private static final RouterFilter.Provider FILTER_PROVIDER = new RouterFilter.Provider(); + + @Test + public void filterType_clientAndServer() { + assertThat(FILTER_PROVIDER.isClientFilter()).isTrue(); + assertThat(FILTER_PROVIDER.isServerFilter()).isTrue(); + } + +} diff --git a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java index ee164938b2d..29b149f166f 100644 --- a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java @@ -18,22 +18,39 @@ import static com.google.common.truth.Truth.assertThat; +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.OAuth2Credentials; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.CallCredentials; +import io.grpc.Grpc; import 
io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; +import io.grpc.Metadata; +import io.grpc.MetricRecorder; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.ObjectPool; import io.grpc.xds.SharedXdsClientPoolProvider.RefCountedXdsClientObjectPool; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.EnvoyProtoData.Node; import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceWatcher; import io.grpc.xds.client.XdsInitializationException; import java.util.Collections; +import java.util.concurrent.TimeUnit; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; @@ -47,25 +64,34 @@ public class SharedXdsClientPoolProviderTest { private static final String SERVER_URI = "trafficdirector.googleapis.com"; @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); private final Node node = Node.newBuilder().setId("SharedXdsClientPoolProviderTest").build(); + private final MetricRecorder metricRecorder = new MetricRecorder() {}; private static final String DUMMY_TARGET = "dummy"; + static final Metadata.Key AUTHORIZATION_METADATA_KEY = + Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER); @Mock private GrpcBootstrapperImpl bootstrapper; + @Mock private ResourceWatcher ldsResourceWatcher; + @Deprecated @Test - public void noServer() throws XdsInitializationException { + public void sharedXdsClientObjectPool_deprecated() throws XdsInitializationException { + ServerInfo 
server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = - BootstrapInfo.builder().servers(Collections.emptyList()).node(node).build(); + BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper); - thrown.expect(XdsInitializationException.class); - thrown.expectMessage("No xDS server provided"); - provider.getOrCreate(DUMMY_TARGET); assertThat(provider.get(DUMMY_TARGET)).isNull(); + ObjectPool xdsClientPool = + provider.getOrCreate(DUMMY_TARGET, metricRecorder, null); + verify(bootstrapper).bootstrap(); + assertThat(provider.getOrCreate(DUMMY_TARGET, bootstrapInfo, metricRecorder)) + .isSameInstanceAs(xdsClientPool); + assertThat(provider.get(DUMMY_TARGET)).isNotNull(); + assertThat(provider.get(DUMMY_TARGET)).isSameInstanceAs(xdsClientPool); + verifyNoMoreInteractions(bootstrapper); } @Test @@ -73,13 +99,14 @@ public void sharedXdsClientObjectPool() throws XdsInitializationException { ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper); assertThat(provider.get(DUMMY_TARGET)).isNull(); - ObjectPool xdsClientPool = provider.getOrCreate(DUMMY_TARGET); - verify(bootstrapper).bootstrap(); - assertThat(provider.getOrCreate(DUMMY_TARGET)).isSameInstanceAs(xdsClientPool); + ObjectPool xdsClientPool = + provider.getOrCreate(DUMMY_TARGET, bootstrapInfo, metricRecorder); + verify(bootstrapper, never()).bootstrap(); + assertThat(provider.getOrCreate(DUMMY_TARGET, bootstrapInfo, metricRecorder)) + .isSameInstanceAs(xdsClientPool); 
assertThat(provider.get(DUMMY_TARGET)).isNotNull(); assertThat(provider.get(DUMMY_TARGET)).isSameInstanceAs(xdsClientPool); verifyNoMoreInteractions(bootstrapper); @@ -90,8 +117,9 @@ public void refCountedXdsClientObjectPool_delayedCreation() { ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(); RefCountedXdsClientObjectPool xdsClientPool = - new RefCountedXdsClientObjectPool(bootstrapInfo, DUMMY_TARGET); + provider.new RefCountedXdsClientObjectPool(bootstrapInfo, DUMMY_TARGET, metricRecorder); assertThat(xdsClientPool.getXdsClientForTest()).isNull(); XdsClient xdsClient = xdsClientPool.getObject(); assertThat(xdsClientPool.getXdsClientForTest()).isNotNull(); @@ -103,8 +131,9 @@ public void refCountedXdsClientObjectPool_refCounted() { ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(); RefCountedXdsClientObjectPool xdsClientPool = - new RefCountedXdsClientObjectPool(bootstrapInfo, DUMMY_TARGET); + provider.new RefCountedXdsClientObjectPool(bootstrapInfo, DUMMY_TARGET, metricRecorder); // getObject once XdsClient xdsClient = xdsClientPool.getObject(); assertThat(xdsClient).isNotNull(); @@ -123,8 +152,9 @@ public void refCountedXdsClientObjectPool_getObjectCreatesNewInstanceIfAlreadySh ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(); RefCountedXdsClientObjectPool xdsClientPool = - new 
RefCountedXdsClientObjectPool(bootstrapInfo, DUMMY_TARGET); + provider.new RefCountedXdsClientObjectPool(bootstrapInfo, DUMMY_TARGET, metricRecorder); XdsClient xdsClient1 = xdsClientPool.getObject(); assertThat(xdsClientPool.returnObject(xdsClient1)).isNull(); assertThat(xdsClient1.isShutDown()).isTrue(); @@ -133,4 +163,61 @@ public void refCountedXdsClientObjectPool_getObjectCreatesNewInstanceIfAlreadySh assertThat(xdsClient2).isNotSameInstanceAs(xdsClient1); xdsClientPool.returnObject(xdsClient2); } + + private class CallCredsServerInterceptor implements ServerInterceptor { + private SettableFuture tokenFuture = SettableFuture.create(); + + @Override + public ServerCall.Listener interceptCall( + ServerCall serverCall, + Metadata metadata, + ServerCallHandler next) { + tokenFuture.set(metadata.get(AUTHORIZATION_METADATA_KEY)); + return next.startCall(serverCall, metadata); + } + + public String getTokenWithTimeout(long timeout, TimeUnit unit) throws Exception { + return tokenFuture.get(timeout, unit); + } + } + + @Test + public void xdsClient_usesCallCredentials() throws Exception { + // Set up fake xDS server + XdsTestControlPlaneService fakeXdsService = new XdsTestControlPlaneService(); + CallCredsServerInterceptor callCredentialsInterceptor = new CallCredsServerInterceptor(); + Server xdsServer = + Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(fakeXdsService) + .intercept(callCredentialsInterceptor) + .build() + .start(); + String xdsServerUri = "localhost:" + xdsServer.getPort(); + + // Set up bootstrap & xDS client pool provider + ServerInfo server = ServerInfo.create(xdsServerUri, InsecureChannelCredentials.create()); + BootstrapInfo bootstrapInfo = + BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(); + + // Create custom xDS transport CallCredentials + CallCredentials sampleCreds = + MoreCallCredentials.from( + 
OAuth2Credentials.create(new AccessToken("token", /* expirationTime= */ null))); + + // Create xDS client that uses the CallCredentials on the transport + ObjectPool xdsClientPool = + provider.getOrCreate("target", bootstrapInfo, metricRecorder, sampleCreds); + XdsClient xdsClient = xdsClientPool.getObject(); + xdsClient.watchXdsResource( + XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher); + + // Wait for xDS server to get the request and verify that it received the CallCredentials + assertThat(callCredentialsInterceptor.getTokenWithTimeout(5, TimeUnit.SECONDS)) + .isEqualTo("Bearer token"); + + // Clean up + xdsClientPool.returnObject(xdsClient); + xdsServer.shutdownNow(); + } } diff --git a/xds/src/test/java/io/grpc/xds/StatefulFilter.java b/xds/src/test/java/io/grpc/xds/StatefulFilter.java new file mode 100644 index 00000000000..4ef662c7ccd --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/StatefulFilter.java @@ -0,0 +1,174 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Message; +import io.grpc.ServerInterceptor; +import java.util.ConcurrentModificationException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.IntStream; +import javax.annotation.Nullable; + +/** + * Unlike most singleton-based filters, each StatefulFilter object has a distinct identity. + */ +class StatefulFilter implements Filter { + + static final String DEFAULT_TYPE_URL = "type.googleapis.com/grpc.test.StatefulFilter"; + private final AtomicBoolean shutdown = new AtomicBoolean(); + + final int idx; + @Nullable volatile String lastCfg = null; + + public StatefulFilter(int idx) { + this.idx = idx; + } + + public boolean isShutdown() { + return shutdown.get(); + } + + @Override + public void close() { + if (!shutdown.compareAndSet(false, true)) { + throw new ConcurrentModificationException( + "Unexpected: StatefulFilter#close called multiple times"); + } + } + + @Nullable + @Override + public ServerInterceptor buildServerInterceptor( + FilterConfig config, + @Nullable FilterConfig overrideConfig) { + Config cfg = (Config) config; + // TODO(sergiitk): to be replaced when name argument passed to the constructor. 
+ lastCfg = cfg.getConfig(); + return null; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder().append("StatefulFilter{") + .append("idx=").append(idx); + if (lastCfg != null) { + sb.append(", name=").append(lastCfg); + } + return sb.append("}").toString(); + } + + static final class Provider implements Filter.Provider { + + private final String typeUrl; + private final ConcurrentMap instances = new ConcurrentHashMap<>(); + + volatile int counter; + + Provider() { + this(DEFAULT_TYPE_URL); + } + + Provider(String typeUrl) { + this.typeUrl = typeUrl; + } + + @Override + public String[] typeUrls() { + return new String[]{ typeUrl }; + } + + @Override + public boolean isClientFilter() { + return true; + } + + @Override + public boolean isServerFilter() { + return true; + } + + @Override + public synchronized StatefulFilter newInstance(String name) { + StatefulFilter filter = new StatefulFilter(counter++); + instances.put(filter.idx, filter); + return filter; + } + + public synchronized StatefulFilter getInstance(int idx) { + return instances.get(idx); + } + + public synchronized ImmutableList getAllInstances() { + return IntStream.range(0, counter).mapToObj(this::getInstance).collect(toImmutableList()); + } + + @SuppressWarnings("UnusedMethod") + public synchronized int getCount() { + return counter; + } + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + return ConfigOrError.fromConfig(Config.fromProto(rawProtoMessage, typeUrl)); + } + + @Override + public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { + return ConfigOrError.fromConfig(Config.fromProto(rawProtoMessage, typeUrl)); + } + } + + + static final class Config implements FilterConfig { + + private final String typeUrl; + private final String config; + + public Config(String config, String typeUrl) { + this.config = config; + this.typeUrl = typeUrl; + } + + public Config(String config) { + this(config, DEFAULT_TYPE_URL); + 
} + + public Config() { + this("", DEFAULT_TYPE_URL); + } + + public static Config fromProto(Message rawProtoMessage, String typeUrl) { + checkNotNull(rawProtoMessage, "rawProtoMessage"); + return new Config(rawProtoMessage.toString(), typeUrl); + } + + public String getConfig() { + return config; + } + + @Override + public String typeUrl() { + return typeUrl; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java index d6240fb09bb..691615762bf 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; import io.grpc.LoadBalancer.PickResult; @@ -30,7 +31,6 @@ import java.util.List; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; @@ -42,9 +42,6 @@ */ @RunWith(JUnit4.class) public class WeightedRandomPickerTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); @@ -128,20 +125,18 @@ public long nextLong(long bound) { public void emptyList() { List emptyList = new ArrayList<>(); - thrown.expect(IllegalArgumentException.class); - new WeightedRandomPicker(emptyList); + assertThrows(IllegalArgumentException.class, () -> new WeightedRandomPicker(emptyList)); } @Test public void negativeWeight() { - thrown.expect(IllegalArgumentException.class); - new WeightedChildPicker(-1, childPicker0); + assertThrows(IllegalArgumentException.class, () -> new WeightedChildPicker(-1, childPicker0)); } @Test public void overWeightSingle() { 
- thrown.expect(IllegalArgumentException.class); - new WeightedChildPicker(Integer.MAX_VALUE * 3L, childPicker0); + assertThrows(IllegalArgumentException.class, + () -> new WeightedChildPicker(Integer.MAX_VALUE * 3L, childPicker0)); } @Test @@ -152,8 +147,8 @@ public void overWeightAggregate() { new WeightedChildPicker(Integer.MAX_VALUE, childPicker1), new WeightedChildPicker(10, childPicker2)); - thrown.expect(IllegalArgumentException.class); - new WeightedRandomPicker(weightedChildPickers, fakeRandom); + assertThrows(IllegalArgumentException.class, + () -> new WeightedRandomPicker(weightedChildPickers, fakeRandom)); } @Test diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java index ddde84ca842..7bd1590885e 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java @@ -111,6 +111,22 @@ public void parseLoadBalancingConfigDefaultValues() throws IOException { assertThat(config.errorUtilizationPenalty).isEqualTo(1.0F); } + @Test + public void parseLoadBalancingConfigCustomMetrics() throws IOException { + System.setProperty("GRPC_EXPERIMENTAL_WRR_CUSTOM_METRICS", "true"); + try { + String lbConfig = "{\"metricNamesForComputingUtilization\" : [\"foo\", \"bar\"]}"; + ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig( + parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + WeightedRoundRobinLoadBalancerConfig config = + (WeightedRoundRobinLoadBalancerConfig) configOrError.getConfig(); + assertThat(config.metricNamesForComputingUtilization).containsExactly("foo", "bar"); + } finally { + System.clearProperty("GRPC_EXPERIMENTAL_WRR_CUSTOM_METRICS"); + } + } + @SuppressWarnings("unchecked") private static Map parseJsonObject(String json) throws IOException { diff --git 
a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index 05ad1f56ece..d495521123a 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -19,9 +19,10 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.ConnectivityState.CONNECTING; import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.eq; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; @@ -32,14 +33,17 @@ import com.github.xds.data.orca.v3.OrcaLoadReport; import com.github.xds.service.orca.v3.OrcaLoadReportRequest; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.testing.EqualsTester; import com.google.protobuf.Duration; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; +import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.DoubleHistogramMetricInstrument; @@ -57,6 +61,7 @@ import io.grpc.Metadata; import io.grpc.MetricRecorder; import io.grpc.MetricSink; +import io.grpc.NameResolver; import io.grpc.NoopMetricSink; import io.grpc.ServerCall; import io.grpc.ServerServiceDefinition; @@ -68,6 +73,7 @@ import io.grpc.internal.PickFirstLoadBalancerProvider; import io.grpc.internal.TestUtils; import io.grpc.internal.testing.StreamRecorder; +import io.grpc.protobuf.ProtoUtils; import 
io.grpc.services.InternalCallMetricRecorder; import io.grpc.services.MetricReport; import io.grpc.stub.ClientCalls; @@ -80,6 +86,7 @@ import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedChildLbState; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker; +import io.grpc.xds.orca.OrcaOobUtilAccessor; import java.net.SocketAddress; import java.util.Arrays; import java.util.Collections; @@ -130,9 +137,6 @@ public class WeightedRoundRobinLoadBalancerTest { private final List servers = Lists.newArrayList(); private final Map, Subchannel> subchannels = Maps.newLinkedHashMap(); - private final Map mockToRealSubChannelMap = new HashMap<>(); - private final Map subchannelStateListeners = - Maps.newLinkedHashMap(); private final Queue> oobCalls = new ConcurrentLinkedQueue<>(); @@ -162,12 +166,17 @@ public void uncaughtException(Thread t, Throwable e) { private String channelTarget = "channel-target"; private String locality = "locality"; + private String backendService = "the-backend-service"; public WeightedRoundRobinLoadBalancerTest() { testHelperInstance = new TestHelper(); helper = mock(Helper.class, delegatesTo(testHelperInstance)); } + private static WeightedRoundRobinPicker getWrrPicker(SubchannelPicker picker) { + return (WeightedRoundRobinPicker) OrcaOobUtilAccessor.getDelegate(picker); + } + @Before public void setup() { for (int i = 0; i < 3; i++) { @@ -188,7 +197,7 @@ public ClientCall answer( return clientCall; } }); - testHelperInstance.setChannel(mockToRealSubChannelMap.get(sc), channel); + testHelperInstance.setChannel(sc, channel); subchannels.put(Arrays.asList(eag), sc); } wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker(), @@ -209,9 +218,42 @@ public void pickChildLbTF() throws Exception { .forTransientFailure(Status.UNAVAILABLE)); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), 
pickerCaptor.capture()); - final WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getValue(); - weightedPicker.pickSubchannel(mockArgs); + final SubchannelPicker picker = pickerCaptor.getValue(); + picker.pickSubchannel(mockArgs); + } + + @Test + public void config_equalsTester() { + WeightedRoundRobinLoadBalancerConfig defaults = + WeightedRoundRobinLoadBalancerConfig.newBuilder().build(); + new EqualsTester() + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder().build(), + WeightedRoundRobinLoadBalancerConfig.newBuilder().build(), + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setBlackoutPeriodNanos(defaults.blackoutPeriodNanos).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setBlackoutPeriodNanos(5).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setWeightExpirationPeriodNanos(5).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setEnableOobLoadReport(true).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setOobReportingPeriodNanos(5).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setWeightUpdatePeriodNanos(5).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setErrorUtilizationPenalty(0.5F).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setErrorUtilizationPenalty(Float.NaN).build()) + .testEquals(); } @Test @@ -237,9 +279,9 @@ public void wrrLifeCycle() { eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); + getWrrPicker(pickerCaptor.getAllValues().get(0)); assertThat(weightedPicker.getChildren().size()).isEqualTo(1); - weightedPicker = (WeightedRoundRobinPicker) 
pickerCaptor.getAllValues().get(1); + weightedPicker = getWrrPicker(pickerCaptor.getAllValues().get(1)); assertThat(weightedPicker.getChildren().size()).isEqualTo(2); String weightedPickerStr = weightedPicker.toString(); assertThat(weightedPickerStr).contains("enableOobLoadReport=false"); @@ -248,16 +290,18 @@ public void wrrLifeCycle() { WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedTasks = isEnabledHappyEyeballs() ? 
2 : 1; assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(expectedTasks); - assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(weightedChild1.getEag()); + assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(servers.get(0)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() .setWeightUpdatePeriodNanos(500_000_000L) //.5s @@ -300,20 +344,21 @@ public void enableOobLoadReportConfig() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.9, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedTasks = isEnabledHappyEyeballs() ? 
2 : 1; assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(expectedTasks); PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); - assertThat(getAddresses(pickResult)) - .isEqualTo(weightedChild1.getEag()); + assertThat(getAddresses(pickResult)).isEqualTo(servers.get(0)); assertThat(pickResult.getStreamTracerFactory()).isNotNull(); // verify per-request listener assertThat(oobCalls.isEmpty()).isTrue(); @@ -325,10 +370,9 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on .setAttributes(affinity).build())); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor2.capture()); - weightedPicker = (WeightedRoundRobinPicker) pickerCaptor2.getAllValues().get(2); - pickResult = weightedPicker.pickSubchannel(mockArgs); - assertThat(getAddresses(pickResult)) - .isEqualTo(weightedChild1.getEag()); + SubchannelPicker rawPicker = pickerCaptor2.getAllValues().get(2); + pickResult = rawPicker.pickSubchannel(mockArgs); + assertThat(getAddresses(pickResult)).isEqualTo(servers.get(0)); assertThat(pickResult.getStreamTracerFactory()).isNull(); OrcaLoadReportRequest golden = OrcaLoadReportRequest.newBuilder().setReportInterval( Duration.newBuilder().setSeconds(20).setNanos(30000000).build()).build(); @@ -360,13 +404,16 @@ private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3, verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); + getWrrPicker(pickerCaptor.getAllValues().get(2)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); - weightedChild1.new 
OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r1); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r2); - weightedChild3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r3); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport(r1); + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport(r2); + weightedChild3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport(r3); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); Map pickCount = new HashMap<>(); @@ -375,16 +422,16 @@ private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3, pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(3); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - subchannel1PickRatio)) + assertThat(Math.abs(pickCount.get(servers.get(0)) / 10000.0 - subchannel1PickRatio)) .isAtMost(0.0002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - subchannel2PickRatio )) + assertThat(Math.abs(pickCount.get(servers.get(1)) / 10000.0 - subchannel2PickRatio )) .isAtMost(0.0002); - assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 10000.0 - subchannel3PickRatio )) + assertThat(Math.abs(pickCount.get(servers.get(2)) / 10000.0 - subchannel3PickRatio )) .isAtMost(0.0002); } private SubchannelStateListener getSubchannelStateListener(Subchannel mockSubChannel) { - return subchannelStateListeners.get(mockToRealSubChannelMap.get(mockSubChannel)); + return testHelperInstance.getSubchannelStateListener(mockSubChannel); } private static ChildLbState getChild(WeightedRoundRobinPicker picker, int index) { @@ -560,13 +607,15 @@ 
public void blackoutPeriod() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedCount = isEnabledHappyEyeballs() ? 
2 : 1; @@ -578,8 +627,8 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on } assertThat(pickCount.size()).isEqualTo(2); // within blackout period, fallback to simple round robin - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 0.5)).isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 0.5)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 10000.0 - 0.5)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 10000.0 - 0.5)).isLessThan(0.002); assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1); pickCount = new HashMap<>(); @@ -589,10 +638,8 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on } assertThat(pickCount.size()).isEqualTo(2); // after blackout period - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 2.0 / 3)) - .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 1.0 / 3)) - .isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 10000.0 - 2.0 / 3)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 10000.0 - 1.0 / 3)).isLessThan(0.002); } private boolean isEnabledHappyEyeballs() { @@ -622,22 +669,23 @@ public void updateWeightTimer() { eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); + getWrrPicker(pickerCaptor.getAllValues().get(0)); assertThat(weightedPicker.getChildren().size()).isEqualTo(1); - weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + weightedPicker = getWrrPicker(pickerCaptor.getAllValues().get(1)); assertThat(weightedPicker.getChildren().size()).isEqualTo(2); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); 
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(expectedTasks); - assertThat(getAddressesFromPick(weightedPicker)) - .isEqualTo(weightedChild1.getEag()); + assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(servers.get(0)); assertThat(getNumFilteredPendingTasks()).isEqualTo(1); weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() .setWeightUpdatePeriodNanos(500_000_000L) //.5s @@ -646,19 +694,19 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); assertThat(getNumFilteredPendingTasks()).isEqualTo(1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( 
+ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); //timer fires, new weight updated expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(expectedTasks); - assertThat(getAddressesFromPick(weightedPicker)) - .isEqualTo(weightedChild2.getEag()); - assertThat(getAddressesFromPick(weightedPicker)) - .isEqualTo(weightedChild1.getEag()); + assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(servers.get(1)); + assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(servers.get(0)); } @Test @@ -680,13 +728,15 @@ public void weightExpired() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int 
expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; @@ -697,10 +747,8 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3)) - .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3)) - .isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 1000.0 - 2.0 / 3)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 1000.0 - 1.0 / 3)).isLessThan(0.002); // weight expired, fallback to simple round robin assertThat(fakeClock.forwardTime(300, TimeUnit.SECONDS)).isEqualTo(1); @@ -710,10 +758,8 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 0.5)) - .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 0.5)) - .isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 1000.0 - 0.5)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 1000.0 - 0.5)).isLessThan(0.002); } @Test @@ -735,29 +781,20 @@ public void rrFallback() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); int expectedTasks = isEnabledHappyEyeballs() ? 
2 : 1; assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(expectedTasks); - WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); - WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - Map qpsByChannel = ImmutableMap.of(weightedChild1.getEag(), 2, - weightedChild2.getEag(), 1); + Map qpsByChannel = ImmutableMap.of(servers.get(0), 2, + servers.get(1), 1); Map pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); EquivalentAddressGroup addresses = getAddresses(pickResult); pickCount.merge(addresses, 1, Integer::sum); - assertThat(pickResult.getStreamTracerFactory()).isNotNull(); - WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses); - childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( - InternalCallMetricRecorder.createMetricReport( - 0.1, 0, 0.1, qpsByChannel.get(addresses), 0, - new HashMap<>(), new HashMap<>(), new HashMap<>())); + reportLoadOnRpc(pickResult, 0.1, 0, 0.1, qpsByChannel.get(addresses), 0); } - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 1.0 / 2)) - .isAtMost(0.1); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 2)) - .isAtMost(0.1); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 1000.0 - 1.0 / 2)).isAtMost(0.1); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 1000.0 - 1.0 / 2)).isAtMost(0.1); // Identical to above except forwards time after each pick pickCount.clear(); @@ -765,19 +802,12 @@ childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLo PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); EquivalentAddressGroup addresses = getAddresses(pickResult); pickCount.merge(addresses, 1, Integer::sum); - assertThat(pickResult.getStreamTracerFactory()).isNotNull(); - WeightedChildLbState childLbState = 
(WeightedChildLbState) wrr.getChildLbStateEag(addresses); - childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( - InternalCallMetricRecorder.createMetricReport( - 0.1, 0, 0.1, qpsByChannel.get(addresses), 0, - new HashMap<>(), new HashMap<>(), new HashMap<>())); + reportLoadOnRpc(pickResult, 0.1, 0, 0.1, qpsByChannel.get(addresses), 0); fakeClock.forwardTime(50, TimeUnit.MILLISECONDS); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3)) - .isAtMost(0.1); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3)) - .isAtMost(0.1); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 1000.0 - 2.0 / 3)).isAtMost(0.1); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 1000.0 - 1.0 / 3)).isAtMost(0.1); } private static EquivalentAddressGroup getAddresses(PickResult pickResult) { @@ -806,14 +836,15 @@ public void unknownWeightIsAvgWeight() { verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); + getWrrPicker(pickerCaptor.getAllValues().get(2)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new 
OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); @@ -823,13 +854,10 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on pickCount.merge(result.getAddresses(), 1, Integer::sum); } assertThat(pickCount.size()).isEqualTo(3); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 4.0 / 9)) - .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 2.0 / 9)) - .isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 1000.0 - 4.0 / 9)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 1000.0 - 2.0 / 9)).isLessThan(0.002); // subchannel3's weight is average of subchannel1 and subchannel2 - assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 1000.0 - 3.0 / 9)) - .isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(2)) / 1000.0 - 3.0 / 9)).isLessThan(0.002); } @Test @@ -851,19 +879,21 @@ public void pickFromOtherThread() throws Exception { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 
1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); CyclicBarrier barrier = new CyclicBarrier(2); Map pickCount = new ConcurrentHashMap<>(); - pickCount.put(weightedChild1.getEag(), new AtomicInteger(0)); - pickCount.put(weightedChild2.getEag(), new AtomicInteger(0)); + pickCount.put(servers.get(0), new AtomicInteger(0)); + pickCount.put(servers.get(1), new AtomicInteger(0)); new Thread(new Runnable() { @Override public void run() { @@ -890,10 +920,8 @@ public void run() { barrier.await(); assertThat(pickCount.size()).isEqualTo(2); // after blackout period - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()).get() / 2000.0 - 2.0 / 3)) - .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()).get() / 2000.0 - 1.0 / 3)) - .isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(0)).get() / 2000.0 - 2.0 / 3)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(1)).get() / 2000.0 - 1.0 / 3)).isLessThan(0.002); } @Test(expected = NullPointerException.class) @@ -1094,7 +1122,7 @@ public void testImmediateWraparound() { .isLessThan(0.002); } } - + @Test public void testWraparound() { float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; @@ -1153,7 +1181,9 @@ public void removingAddressShutsdownSubchannel() { public void metrics() { // Give WRR some valid addresses to work with. 
Attributes attributesWithLocality = Attributes.newBuilder() - .set(WeightedTargetLoadBalancer.CHILD_NAME, locality).build(); + .set(WeightedTargetLoadBalancer.CHILD_NAME, locality) + .set(NameResolver.ATTR_BACKEND_SERVICE, backendService) + .build(); syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(attributesWithLocality).build())); @@ -1193,22 +1223,22 @@ public void metrics() { // Send one child LB state an ORCA update with some valid utilization/qps data so that weights // can be calculated, but it's still essentially round_robin Iterator childLbStates = wrr.getChildLbStates().iterator(); - ((WeightedChildLbState)childLbStates.next()).new OrcaReportListener( - weightedConfig.errorUtilizationPenalty).onLoadReport( - InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(), - new HashMap<>(), new HashMap<>())); + ((WeightedChildLbState) childLbStates.next()).new OrcaReportListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization) + .onLoadReport(InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), new HashMap<>())); fakeClock.forwardTime(1, TimeUnit.SECONDS); // Now send a second child LB state an ORCA update, so there's real weights - ((WeightedChildLbState)childLbStates.next()).new OrcaReportListener( - weightedConfig.errorUtilizationPenalty).onLoadReport( - InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(), - new HashMap<>(), new HashMap<>())); - ((WeightedChildLbState)childLbStates.next()).new OrcaReportListener( - weightedConfig.errorUtilizationPenalty).onLoadReport( - InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(), - new HashMap<>(), new HashMap<>())); + ((WeightedChildLbState) childLbStates.next()).new OrcaReportListener( + weightedConfig.errorUtilizationPenalty, 
weightedConfig.metricNamesForComputingUtilization) + .onLoadReport(InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), new HashMap<>())); + ((WeightedChildLbState) childLbStates.next()).new OrcaReportListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization) + .onLoadReport(InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), new HashMap<>())); // Let's reset the mock MetricsRecorder so that it's easier to verify what happened after the // weights were updated @@ -1298,12 +1328,257 @@ public void metricWithRealChannel() throws Exception { assertThat(recorder.getError()).isNull(); // Make sure at least one metric works. The other tests will make sure other metrics and the - // edge cases are working. - verify(metrics).addLongCounter( + // edge cases are working. Since this is racy, we just care it happened at least once. + verify(metrics, atLeast(1)).addLongCounter( argThat((instr) -> instr.getName().equals("grpc.lb.wrr.rr_fallback")), eq(1L), eq(Arrays.asList("directaddress:///wrr-metrics")), - eq(Arrays.asList(""))); + eq(Arrays.asList("", ""))); + } + + + @Test + public void customMetric_priority_overAppUtil() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost")).build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = 
+ (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", 0.5); + // App util = 0.8 + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0.8, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // Custom metrics now take priority over app_util + // qps=1, util=0.5 -> weight=2.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(2.0), any(), + any()); + } + + @Test + public void customMetric_invalid_fallbackToAppUtil() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost")).build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + // custom metric is NaN, App util = 0.8 + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", Double.NaN); + 
MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0.8, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + + // Should fallback to App Util (0.8) + // qps=1, util=0.8 -> weight=1.25 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(1.25), any(), + any()); + } + + @Test + public void customMetric_mapLookup_used() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost")).build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", 0.5); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.5 -> weight=2.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(2.0), any(), + any()); + } + + 
@Test + public void customMetric_shouldFilterOutAndFallbackToCpu() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost")).build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + // custom metric is NaN, but CPU is 0.1 + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", Double.NaN); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + + // Should fallback to CPU (0.1) + // fallback to cpu: qps=1, util=0.1 -> weight=10.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(10.0), any(), + any()); + } + + @Test + public void customMetric_multipleMetrics_maxUsed() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization( + ImmutableList.of("named_metrics.cost", "named_metrics.score")) + .build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, 
fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", 0.5); + namedMetrics.put("score", 0.8); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.8 (max of 0.5 and 0.8) -> weight=1.25 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(1.25), any(), + any()); + } + + @Test + public void customMetric_allInvalid_fallbackToCpu() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization( + ImmutableList.of("named_metrics.cost", "named_metrics.score", "named_metrics.other")) + .build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + 
getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", Double.NaN); + namedMetrics.put("score", 0.0); + namedMetrics.put("other", -1.0); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.1 (fallback to cpu) -> weight=10.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(10.0), any(), + any()); + } + + @Test + public void customMetric_mixInvalidAndValid_validUsed() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost", + "named_metrics.score", "named_metrics.other1", "named_metrics.other2")) + .build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener 
= weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", Double.NaN); + namedMetrics.put("score", 0.5); + namedMetrics.put("other1", 0.0); + namedMetrics.put("other2", -123.0); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.5 -> weight=2.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(2.0), any(), + any()); } // Verifies that the MetricRecorder has been called to record a long counter value of 1 for the @@ -1315,7 +1590,10 @@ private void verifyLongCounterRecord(String name, int times, long value) { public boolean matches(LongCounterMetricInstrument longCounterInstrument) { return longCounterInstrument.getName().equals(name); } - }), eq(value), eq(Lists.newArrayList(channelTarget)), eq(Lists.newArrayList(locality))); + }), + eq(value), + eq(Lists.newArrayList(channelTarget)), + eq(Lists.newArrayList(locality, backendService))); } // Verifies that the MetricRecorder has been called to record a given double histogram value the @@ -1327,13 +1605,45 @@ private void verifyDoubleHistogramRecord(String name, int times, double value) { public boolean matches(DoubleHistogramMetricInstrument doubleHistogramInstrument) { return doubleHistogramInstrument.getName().equals(name); } - }), eq(value), eq(Lists.newArrayList(channelTarget)), eq(Lists.newArrayList(locality))); + }), + eq(value), + eq(Lists.newArrayList(channelTarget)), + eq(Lists.newArrayList(locality, backendService))); } private int getNumFilteredPendingTasks() { return AbstractTestHelper.getNumFilteredPendingTasks(fakeClock); } + private static final Metadata.Key ORCA_LOAD_METRICS_KEY = + 
Metadata.Key.of( + "endpoint-load-metrics-bin", + ProtoUtils.metadataMarshaller(OrcaLoadReport.getDefaultInstance())); + private static final ClientStreamTracer.StreamInfo STREAM_INFO = + ClientStreamTracer.StreamInfo.newBuilder().build(); + + private static void reportLoadOnRpc( + PickResult pickResult, + double cpuUtilization, + double applicationUtilization, + double memoryUtilization, + double qps, + double eps) { + ClientStreamTracer childTracer = pickResult.getStreamTracerFactory() + .newClientStreamTracer(STREAM_INFO, new Metadata()); + Metadata trailer = new Metadata(); + trailer.put( + ORCA_LOAD_METRICS_KEY, + OrcaLoadReport.newBuilder() + .setCpuUtilization(cpuUtilization) + .setApplicationUtilization(applicationUtilization) + .setMemUtilization(memoryUtilization) + .setRpsFractional(qps) + .setEps(eps) + .build()); + childTracer.inboundTrailers(trailer); + } + private static final class VerifyingScheduler { private final StaticStrideScheduler delegate; private final int max; @@ -1389,16 +1699,6 @@ public Map, Subchannel> getSubchannelMap() { return subchannels; } - @Override - public Map getMockToRealSubChannelMap() { - return mockToRealSubChannelMap; - } - - @Override - public Map getSubchannelStateListeners() { - return subchannelStateListeners; - } - @Override public MetricRecorder getMetricRecorder() { return mockMetricRecorder; diff --git a/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java index cc6cb98412c..55ff0cd8078 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java @@ -113,6 +113,7 @@ public String getPolicyName() { public LoadBalancer newLoadBalancer(Helper helper) { childHelpers.add(helper); LoadBalancer childBalancer = mock(LoadBalancer.class); + when(childBalancer.acceptResolvedAddresses(any())).thenReturn(Status.OK); childBalancers.add(childBalancer); 
fooLbCreated++; return childBalancer; @@ -139,6 +140,7 @@ public String getPolicyName() { public LoadBalancer newLoadBalancer(Helper helper) { childHelpers.add(helper); LoadBalancer childBalancer = mock(LoadBalancer.class); + when(childBalancer.acceptResolvedAddresses(any())).thenReturn(Status.OK); childBalancers.add(childBalancer); barLbCreated++; return childBalancer; @@ -180,7 +182,7 @@ public void tearDown() { } @Test - public void handleResolvedAddresses() { + public void acceptResolvedAddresses() { ArgumentCaptor resolvedAddressesCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); Attributes.Key fakeKey = Attributes.Key.create("fake_key"); @@ -203,12 +205,13 @@ public void handleResolvedAddresses() { eag2 = AddressFilter.setPathFilter(eag2, ImmutableList.of("target2")); EquivalentAddressGroup eag3 = new EquivalentAddressGroup(socketAddresses[3]); eag3 = AddressFilter.setPathFilter(eag3, ImmutableList.of("target3")); - weightedTargetLb.handleResolvedAddresses( + Status status = weightedTargetLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of(eag0, eag1, eag2, eag3)) .setAttributes(Attributes.newBuilder().set(fakeKey, fakeValue).build()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) .build()); + assertThat(status.isOk()).isTrue(); verify(helper).updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); assertThat(childBalancers).hasSize(4); assertThat(childHelpers).hasSize(4); @@ -216,7 +219,7 @@ public void handleResolvedAddresses() { assertThat(barLbCreated).isEqualTo(2); for (int i = 0; i < childBalancers.size(); i++) { - verify(childBalancers.get(i)).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(childBalancers.get(i)).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); ResolvedAddresses resolvedAddresses = resolvedAddressesCaptor.getValue(); assertThat(resolvedAddresses.getLoadBalancingPolicyConfig()).isEqualTo(configs[i]); 
assertThat(resolvedAddresses.getAttributes().get(fakeKey)).isEqualTo(fakeValue); @@ -226,6 +229,11 @@ public void handleResolvedAddresses() { .containsExactly(socketAddresses[i]); } + // Even when a child return an error from the update, the other children should still receive + // their updates. + Status acceptReturnStatus = Status.UNAVAILABLE.withDescription("Didn't like something"); + when(childBalancers.get(2).acceptResolvedAddresses(any())).thenReturn(acceptReturnStatus); + // Update new weighted target config for a typical workflow. // target0 removed. target1, target2, target3 changed weight and config. target4 added. int[] newWeights = new int[]{11, 22, 33, 44}; @@ -243,11 +251,12 @@ public void handleResolvedAddresses() { "target4", new WeightedPolicySelection( newWeights[3], newChildConfig(fooLbProvider, newConfigs[3]))); - weightedTargetLb.handleResolvedAddresses( + status = weightedTargetLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(newTargets)) .build()); + assertThat(status.getCode()).isEqualTo(acceptReturnStatus.getCode()); verify(helper, atLeast(2)) .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); assertThat(childBalancers).hasSize(5); @@ -258,7 +267,7 @@ public void handleResolvedAddresses() { verify(childBalancers.get(0)).shutdown(); for (int i = 1; i < childBalancers.size(); i++) { verify(childBalancers.get(i), atLeastOnce()) - .handleResolvedAddresses(resolvedAddressesCaptor.capture()); + .acceptResolvedAddresses(resolvedAddressesCaptor.capture()); assertThat(resolvedAddressesCaptor.getValue().getLoadBalancingPolicyConfig()) .isEqualTo(newConfigs[i - 1]); } @@ -286,7 +295,7 @@ public void handleNameResolutionError() { "target2", weightedLbConfig2, // {foo, 40, config3} "target3", weightedLbConfig3); - weightedTargetLb.handleResolvedAddresses( + weightedTargetLb.acceptResolvedAddresses( 
ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) @@ -313,7 +322,7 @@ public void balancingStateUpdatedFromChildBalancers() { "target2", weightedLbConfig2, // {foo, 40, config3} "target3", weightedLbConfig3); - weightedTargetLb.handleResolvedAddresses( + weightedTargetLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) @@ -395,7 +404,7 @@ public void raceBetweenShutdownAndChildLbBalancingStateUpdate() { Map targets = ImmutableMap.of( "target0", weightedLbConfig0, "target1", weightedLbConfig1); - weightedTargetLb.handleResolvedAddresses( + weightedTargetLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) @@ -421,7 +430,7 @@ public void noDuplicateOverallBalancingStateUpdate() { weights[0], newChildConfig(fakeLbProvider, configs[0])), "target3", new WeightedPolicySelection( weights[3], newChildConfig(fakeLbProvider, configs[3]))); - weightedTargetLb.handleResolvedAddresses( + weightedTargetLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) @@ -470,9 +479,10 @@ static class FakeLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { helper.updateBalancingState( TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(Status.INTERNAL))); + return Status.OK; } @Override diff --git a/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java index fcf8c826d86..584c32738c5 100644 --- a/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java +++ 
b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java @@ -108,7 +108,7 @@ public void setUp() { } @Test - public void handleResolvedAddresses() { + public void acceptResolvedAddresses() { // A two locality cluster with a mock child LB policy. String localityOne = "localityOne"; String localityTwo = "localityTwo"; @@ -124,7 +124,7 @@ public void handleResolvedAddresses() { // Assert that the child policy and the locality weights were correctly mapped to a // WeightedTargetConfig. - verify(mockWeightedTargetLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(mockWeightedTargetLb).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); Object config = resolvedAddressesCaptor.getValue().getLoadBalancingPolicyConfig(); assertThat(config).isInstanceOf(WeightedTargetConfig.class); WeightedTargetConfig wtConfig = (WeightedTargetConfig) config; @@ -136,7 +136,7 @@ public void handleResolvedAddresses() { } @Test - public void handleResolvedAddresses_noLocalityWeights() { + public void acceptResolvedAddresses_noLocalityWeights() { // A two locality cluster with a mock child LB policy. Object childPolicy = newChildConfig(mockChildProvider, null); @@ -182,10 +182,10 @@ public void localityWeightAttributeNotPropagated() { // Assert that the child policy and the locality weights were correctly mapped to a // WeightedTargetConfig. 
- verify(mockWeightedTargetLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(mockWeightedTargetLb).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); //assertThat(resolvedAddressesCaptor.getValue().getAttributes() - // .get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHTS)).isNull(); + // .get(XdsAttributes.ATTR_LOCALITY_WEIGHTS)).isNull(); } @Test @@ -213,7 +213,7 @@ private Object newChildConfig(LoadBalancerProvider provider, Object config) { } private void deliverAddresses(WrrLocalityConfig config, List addresses) { - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(addresses).setLoadBalancingPolicyConfig(config) .build()); } @@ -254,9 +254,9 @@ public String toString() { } Attributes.Builder attrBuilder = Attributes.newBuilder() - .set(InternalXdsAttributes.ATTR_LOCALITY_NAME, locality); + .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, locality); if (localityWeight != null) { - attrBuilder.set(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT, localityWeight); + attrBuilder.set(XdsAttributes.ATTR_LOCALITY_WEIGHT, localityWeight); } EquivalentAddressGroup eag = new EquivalentAddressGroup(new FakeSocketAddress(name), diff --git a/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java b/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java new file mode 100644 index 00000000000..27ee8d22825 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java @@ -0,0 +1,610 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; +import io.grpc.MetricRecorder; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.internal.ExponentialBackoffPolicy; +import io.grpc.internal.FakeClock; +import io.grpc.internal.ObjectPool; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsListenerResource.LdsUpdate; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; +import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; +import io.grpc.xds.client.LoadReportClient; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClientImpl; +import io.grpc.xds.client.XdsClientMetricReporter; +import io.grpc.xds.client.XdsInitializationException; +import io.grpc.xds.client.XdsTransportFactory; +import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.Collections; +import 
java.util.Map; +import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class XdsClientFallbackTest { + private static final Logger log = Logger.getLogger(XdsClientFallbackTest.class.getName()); + + private static final String MAIN_SERVER = "main-server"; + private static final String FALLBACK_SERVER = "fallback-server"; + private static final String DUMMY_TARGET = "TEST_TARGET"; + private static final String RDS_NAME = "route-config.googleapis.com"; + private static final String FALLBACK_RDS_NAME = "fallback-" + RDS_NAME; + private static final String CLUSTER_NAME = "cluster0"; + private static final String FALLBACK_CLUSTER_NAME = "fallback-" + CLUSTER_NAME; + private static final String EDS_NAME = "eds-service-0"; + private static final String FALLBACK_EDS_NAME = "fallback-" + EDS_NAME; + private static final HttpConnectionManager MAIN_HTTP_CONNECTION_MANAGER = + HttpConnectionManager.forRdsName(0, RDS_NAME, ImmutableList.of( + new Filter.NamedFilterConfig("terminal-filter", RouterFilter.ROUTER_CONFIG))); + private static final HttpConnectionManager FALLBACK_HTTP_CONNECTION_MANAGER = + HttpConnectionManager.forRdsName(0, FALLBACK_RDS_NAME, ImmutableList.of( + new Filter.NamedFilterConfig("terminal-filter", RouterFilter.ROUTER_CONFIG))); + private ObjectPool xdsClientPool; + private XdsClient xdsClient; + private boolean originalEnableXdsFallback; + private final FakeClock fakeClock = new 
FakeClock(); + private final MetricRecorder metricRecorder = new MetricRecorder() {}; + + @Mock + private XdsClientMetricReporter xdsClientMetricReporter; + + @Captor + private ArgumentCaptor> ldsUpdateCaptor; + @Captor + private ArgumentCaptor> rdsUpdateCaptor; + + private final XdsClient.ResourceWatcher raalLdsWatcher = + new XdsClient.ResourceWatcher() { + + @Override + public void onResourceChanged(StatusOr update) { + if (update.hasValue()) { + log.log(Level.FINE, "LDS update: " + update.getValue()); + } else { + log.log(Level.FINE, "LDS resource error: " + update.getStatus().getDescription()); + } + } + + @Override + public void onAmbientError(Status error) { + log.log(Level.FINE, "LDS ambient error: " + error.getDescription()); + } + }; + + @SuppressWarnings("unchecked") + private final XdsClient.ResourceWatcher ldsWatcher = + mock(XdsClient.ResourceWatcher.class, delegatesTo(raalLdsWatcher)); + @Mock + private XdsClient.ResourceWatcher ldsWatcher2; + + @Mock + private XdsClient.ResourceWatcher rdsWatcher; + @Mock + private XdsClient.ResourceWatcher rdsWatcher2; + @Mock + private XdsClient.ResourceWatcher rdsWatcher3; + + private final XdsClient.ResourceWatcher raalCdsWatcher = + new XdsClient.ResourceWatcher() { + + @Override + public void onResourceChanged(StatusOr update) { + if (update.hasValue()) { + log.log(Level.FINE, "CDS update: " + update.getValue()); + } else { + log.log(Level.FINE, "CDS resource error: " + update.getStatus().getDescription()); + } + } + + @Override + public void onAmbientError(Status error) { + // Logic from the old onError method for transient errors. 
+ log.log(Level.FINE, "CDS ambient error: " + error.getDescription()); + } + }; + + @SuppressWarnings("unchecked") + private final XdsClient.ResourceWatcher cdsWatcher = + mock(XdsClient.ResourceWatcher.class, delegatesTo(raalCdsWatcher)); + @Mock + private XdsClient.ResourceWatcher cdsWatcher2; + + @Rule(order = 0) + public ControlPlaneRule mainXdsServer = + new ControlPlaneRule().setServerHostName(MAIN_SERVER); + + @Rule(order = 1) + public ControlPlaneRule fallbackServer = + new ControlPlaneRule().setServerHostName(MAIN_SERVER); + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + @Before + public void setUp() throws XdsInitializationException { + originalEnableXdsFallback = CommonBootstrapperTestUtils.setEnableXdsFallback(true); + if (mainXdsServer == null) { + throw new XdsInitializationException("Failed to create ControlPlaneRule for main TD server"); + } + setAdsConfig(mainXdsServer, MAIN_SERVER); + setAdsConfig(fallbackServer, FALLBACK_SERVER); + + SharedXdsClientPoolProvider clientPoolProvider = new SharedXdsClientPoolProvider(); + xdsClientPool = clientPoolProvider.getOrCreate( + DUMMY_TARGET, + new GrpcBootstrapperImpl().bootstrap(defaultBootstrapOverride()), + metricRecorder); + } + + @After + public void cleanUp() { + if (xdsClient != null) { + xdsClient = xdsClientPool.returnObject(xdsClient); + } + CommonBootstrapperTestUtils.setEnableXdsFallback(originalEnableXdsFallback); + } + + private static void setAdsConfig(ControlPlaneRule controlPlane, String serverName) { + InetSocketAddress edsInetSocketAddress = + (InetSocketAddress) controlPlane.getServer().getListenSockets().get(0); + boolean isMainServer = serverName.equals(MAIN_SERVER); + String rdsName = isMainServer + ? RDS_NAME + : FALLBACK_RDS_NAME; + String clusterName = isMainServer ? CLUSTER_NAME : FALLBACK_CLUSTER_NAME; + String edsName = isMainServer ? 
EDS_NAME : FALLBACK_EDS_NAME; + + controlPlane.setLdsConfig(ControlPlaneRule.buildServerListener(), + ControlPlaneRule.buildClientListener(MAIN_SERVER, rdsName)); + + controlPlane.setRdsConfig(rdsName, + XdsTestUtils.buildRouteConfiguration(MAIN_SERVER, rdsName, clusterName)); + controlPlane.setCdsConfig(clusterName, ControlPlaneRule.buildCluster(clusterName, edsName)); + + controlPlane.setEdsConfig(edsName, + ControlPlaneRule.buildClusterLoadAssignment(edsInetSocketAddress.getHostName(), + DataPlaneRule.ENDPOINT_HOST_NAME, edsInetSocketAddress.getPort(), edsName)); + log.log(Level.FINE, + String.format("Set ADS config for %s with address %s", serverName, edsInetSocketAddress)); + } + + // This is basically a control test to make sure everything is set up correctly. + @Test + public void everything_okay() { + mainXdsServer.restartXdsServer(); + fallbackServer.restartXdsServer(); + xdsClient = xdsClientPool.getObject(); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + verify(ldsWatcher, timeout(5000)).onResourceChanged(ldsUpdateCaptor.capture()); + assertThat(ldsUpdateCaptor.getValue().hasValue()).isTrue(); + assertThat(ldsUpdateCaptor.getValue().getValue()).isEqualTo( + LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER)); + + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_NAME, rdsWatcher); + verify(rdsWatcher, timeout(5000)).onResourceChanged(rdsUpdateCaptor.capture()); + assertThat(rdsUpdateCaptor.getValue().hasValue()).isTrue(); + } + + @Test + public void mainServerDown_fallbackServerUp() { + mainXdsServer.getServer().shutdownNow(); + fallbackServer.restartXdsServer(); + xdsClient = xdsClientPool.getObject(); + log.log(Level.FINE, "Fallback port = " + fallbackServer.getServer().getPort()); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + + verify(ldsWatcher, timeout(5000)).onResourceChanged( + 
StatusOr.fromValue(XdsListenerResource.LdsUpdate.forApiListener( + FALLBACK_HTTP_CONNECTION_MANAGER))); + } + + @Test + public void useBadAuthority() { + xdsClient = xdsClientPool.getObject(); + InOrder inOrder = inOrder(ldsWatcher, rdsWatcher, rdsWatcher2, rdsWatcher3); + + String badPrefix = "xdstp://authority.xds.bad/envoy.config.listener.v3.Listener/"; + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), + badPrefix + "listener.googleapis.com", ldsWatcher); + inOrder.verify(ldsWatcher, timeout(5000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue())); + + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), + badPrefix + "route-config.googleapis.bad", rdsWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), + badPrefix + "route-config2.googleapis.bad", rdsWatcher2); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), + badPrefix + "route-config3.googleapis.bad", rdsWatcher3); + inOrder.verify(rdsWatcher, timeout(5000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue())); + inOrder.verify(rdsWatcher2, timeout(5000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue())); + inOrder.verify(rdsWatcher3, timeout(5000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue())); + verify(rdsWatcher, never()).onResourceChanged(argThat(StatusOr::hasValue)); + + // even after an error, a valid one will still work + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher2); + verify(ldsWatcher2, timeout(5000)).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOr = ldsUpdateCaptor.getValue(); + assertThat(statusOr.hasValue()).isTrue(); + assertThat(statusOr.getValue()).isEqualTo( + XdsListenerResource.LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER)); + } + + @Test + public void both_down_restart_main() { + mainXdsServer.getServer().shutdownNow(); + fallbackServer.getServer().shutdownNow(); + xdsClient = 
xdsClientPool.getObject(); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + verify(ldsWatcher, timeout(5000).atLeastOnce()) + .onResourceChanged(argThat(statusOr -> !statusOr.hasValue())); + verify(ldsWatcher, never()).onResourceChanged(argThat(StatusOr::hasValue)); + xdsClient.watchXdsResource( + XdsRouteConfigureResource.getInstance(), RDS_NAME, rdsWatcher2); + verify(rdsWatcher2, timeout(5000).atLeastOnce()) + .onResourceChanged(argThat(statusOr -> !statusOr.hasValue())); + + mainXdsServer.restartXdsServer(); + + xdsClient.watchXdsResource( + XdsRouteConfigureResource.getInstance(), RDS_NAME, rdsWatcher); + + verify(ldsWatcher, timeout(16000)).onResourceChanged( + argThat(statusOr -> statusOr.hasValue() && statusOr.getValue().equals( + XdsListenerResource.LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER)))); + verify(rdsWatcher, timeout(5000)).onResourceChanged(argThat(StatusOr::hasValue)); + verify(rdsWatcher2, timeout(5000)).onResourceChanged(argThat(StatusOr::hasValue)); + } + + @Test + public void mainDown_fallbackUp_restart_main() { + mainXdsServer.getServer().shutdownNow(); + fallbackServer.restartXdsServer(); + xdsClient = xdsClientPool.getObject(); + InOrder inOrder = inOrder(ldsWatcher, rdsWatcher, cdsWatcher, cdsWatcher2); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + inOrder.verify(ldsWatcher, timeout(5000)).onResourceChanged( + StatusOr.fromValue(XdsListenerResource.LdsUpdate.forApiListener( + FALLBACK_HTTP_CONNECTION_MANAGER))); + + // Watch another resource, also from the fallback server. 
+ xdsClient.watchXdsResource(XdsClusterResource.getInstance(), FALLBACK_CLUSTER_NAME, cdsWatcher); + @SuppressWarnings("unchecked") + ArgumentCaptor> cdsUpdateCaptor1 = ArgumentCaptor.forClass(StatusOr.class); + inOrder.verify(cdsWatcher, timeout(5000)).onResourceChanged(cdsUpdateCaptor1.capture()); + assertThat(cdsUpdateCaptor1.getValue().getStatus().isOk()).isTrue(); + + assertThat(fallbackServer.getService().getSubscriberCounts() + .get("type.googleapis.com/envoy.config.listener.v3.Listener")).isEqualTo(1); + verifyNoSubscribers(mainXdsServer); + + mainXdsServer.restartXdsServer(); + + // The existing ldsWatcher should receive a new update from the main server. + // Note: This is not an inOrder verification because the timing of the switchover + // can vary. We just need to verify it happens. + verify(ldsWatcher, timeout(5000)).onResourceChanged( + StatusOr.fromValue(XdsListenerResource.LdsUpdate.forApiListener( + MAIN_HTTP_CONNECTION_MANAGER))); + + // Watch a new resource; should now come from the main server. 
+ xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_NAME, rdsWatcher); + @SuppressWarnings("unchecked") + ArgumentCaptor> rdsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + inOrder.verify(rdsWatcher, timeout(5000)).onResourceChanged(rdsUpdateCaptor.capture()); + assertThat(rdsUpdateCaptor.getValue().getStatus().isOk()).isTrue(); + verifyNoSubscribers(fallbackServer); + + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CLUSTER_NAME, cdsWatcher2); + @SuppressWarnings("unchecked") + ArgumentCaptor> cdsUpdateCaptor2 = ArgumentCaptor.forClass(StatusOr.class); + inOrder.verify(cdsWatcher2, timeout(5000)).onResourceChanged(cdsUpdateCaptor2.capture()); + assertThat(cdsUpdateCaptor2.getValue().getStatus().isOk()).isTrue(); + + verifyNoSubscribers(fallbackServer); + assertThat(mainXdsServer.getService().getSubscriberCounts() + .get("type.googleapis.com/envoy.config.listener.v3.Listener")).isEqualTo(1); + } + + private static void verifyNoSubscribers(ControlPlaneRule rule) { + for (Map.Entry me : rule.getService().getSubscriberCounts().entrySet()) { + String type = me.getKey(); + Integer count = me.getValue(); + assertWithMessage("Type with non-zero subscribers is: %s", type) + .that(count).isEqualTo(0); + } + } + + // This test takes a long time because of the 16 sec timeout for non-existent resource + @Test + public void connect_then_mainServerDown_fallbackServerUp() throws Exception { + mainXdsServer.restartXdsServer(); + fallbackServer.restartXdsServer(); + ExecutorService executor = Executors.newFixedThreadPool(1); + XdsTransportFactory xdsTransportFactory = new XdsTransportFactory() { + @Override + public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { + ChannelCredentials channelCredentials = + (ChannelCredentials) serverInfo.implSpecificConfig(); + return new GrpcXdsTransportFactory.GrpcXdsTransport( + Grpc.newChannelBuilder(serverInfo.target(), channelCredentials) + .executor(executor) + .build()); + } + }; + 
XdsClientImpl xdsClient = CommonBootstrapperTestUtils.createXdsClient( + new GrpcBootstrapperImpl().bootstrap(defaultBootstrapOverride()), + xdsTransportFactory, fakeClock, new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, xdsClientMetricReporter); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + + // Initial resource fetch from the main server + verify(ldsWatcher, timeout(5000)).onResourceChanged( + StatusOr.fromValue(LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER))); + + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_NAME, rdsWatcher); + verify(rdsWatcher, timeout(5000)).onResourceChanged(argThat(StatusOr::hasValue)); + + mainXdsServer.getServer().shutdownNow(); + // Sleep for the ADS stream disconnect to be processed and for the retry to fail. Between those + // two sleeps we need the fakeClock to progress by 1 second to restart the ADS stream. + for (int i = 0; i < 5; i++) { + // FakeClock is not thread-safe, and the retry scheduling is concurrent to this test thread + executor.submit(() -> fakeClock.forwardTime(1000, TimeUnit.MILLISECONDS)).get(); + TimeUnit.SECONDS.sleep(1); + } + + // Shouldn't do fallback since all watchers are loaded + verify(ldsWatcher, never()).onResourceChanged(StatusOr.fromValue( + XdsListenerResource.LdsUpdate.forApiListener(FALLBACK_HTTP_CONNECTION_MANAGER))); + + // Should just get from cache + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher2); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_NAME, rdsWatcher2); + verify(ldsWatcher2, timeout(5000)).onResourceChanged(StatusOr.fromValue( + XdsListenerResource.LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER))); + verify(ldsWatcher, never()).onResourceChanged(StatusOr.fromValue( + XdsListenerResource.LdsUpdate.forApiListener(FALLBACK_HTTP_CONNECTION_MANAGER))); + // Make sure that rdsWatcher wasn't called again + 
verify(rdsWatcher, times(1)).onResourceChanged(any()); + verify(rdsWatcher2, timeout(5000)).onResourceChanged(argThat(StatusOr::hasValue)); + + // Asking for something not in cache should force a fallback + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), FALLBACK_CLUSTER_NAME, cdsWatcher); + verify(ldsWatcher, timeout(5000)).onResourceChanged(StatusOr.fromValue( + XdsListenerResource.LdsUpdate.forApiListener(FALLBACK_HTTP_CONNECTION_MANAGER))); + verify(ldsWatcher2, timeout(5000)).onResourceChanged(StatusOr.fromValue( + XdsListenerResource.LdsUpdate.forApiListener(FALLBACK_HTTP_CONNECTION_MANAGER))); + verify(cdsWatcher, timeout(5000)).onResourceChanged(argThat(StatusOr::hasValue)); + + xdsClient.watchXdsResource( + XdsRouteConfigureResource.getInstance(), FALLBACK_RDS_NAME, rdsWatcher3); + verify(rdsWatcher3, timeout(5000)).onResourceChanged(argThat(StatusOr::hasValue)); + + // Test that resource defined in main but not fallback is handled correctly + xdsClient.watchXdsResource( + XdsClusterResource.getInstance(), CLUSTER_NAME, cdsWatcher2); + verify(cdsWatcher2, never()).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.NOT_FOUND)); + fakeClock.forwardTime(15000, TimeUnit.MILLISECONDS); // Does not exist timer + verify(cdsWatcher2, timeout(5000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.NOT_FOUND + && statusOr.getStatus().getDescription().contains(CLUSTER_NAME))); + xdsClient.shutdown(); + executor.shutdown(); + } + + @Test + public void connect_then_mainServerRestart_fallbackServerdown() { + mainXdsServer.restartXdsServer(); + xdsClient = xdsClientPool.getObject(); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + + verify(ldsWatcher, timeout(5000)).onResourceChanged( + argThat(statusOr -> statusOr.hasValue() && statusOr.getValue().equals( + 
LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER)))); + mainXdsServer.getServer().shutdownNow(); + fallbackServer.getServer().shutdownNow(); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CLUSTER_NAME, cdsWatcher); + mainXdsServer.restartXdsServer(); + + verify(cdsWatcher, timeout(5000)).onResourceChanged( + argThat(statusOr -> statusOr.hasValue())); + verify(ldsWatcher, timeout(5000).atLeastOnce()).onResourceChanged( + argThat(statusOr -> statusOr.hasValue() && statusOr.getValue().equals( + LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER)))); + } + + @Test + public void fallbackFromBadUrlToGoodOne() { + // Setup xdsClient to fail on stream creation + String garbageUri = "some. garbage"; + + String validUri = "localhost:" + mainXdsServer.getServer().getPort(); + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(garbageUri, validUri), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); + + client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + fakeClock.forwardTime(20, TimeUnit.SECONDS); + verify(ldsWatcher, timeout(5000)).onResourceChanged( + StatusOr.fromValue(XdsListenerResource.LdsUpdate.forApiListener( + MAIN_HTTP_CONNECTION_MANAGER))); + verify(ldsWatcher, never()).onAmbientError(any(Status.class)); + + client.shutdown(); + } + + @Test + public void testGoodUrlFollowedByBadUrl() { + // xdsClient should succeed in stream creation as it doesn't need to use the bad url + String garbageUri = "some. 
garbage"; + String validUri = "localhost:" + mainXdsServer.getServer().getPort(); + + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(validUri, garbageUri), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); + + client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + verify(ldsWatcher, timeout(5000)).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOr = ldsUpdateCaptor.getValue(); + assertThat(statusOr.hasValue()).isTrue(); + assertThat(statusOr.getValue()).isEqualTo( + XdsListenerResource.LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER)); + verify(ldsWatcher, never()).onAmbientError(any()); + verify(ldsWatcher, times(1)).onResourceChanged(any()); + + client.shutdown(); + } + + @Test + public void testTwoBadUrl() { + // Setup xdsClient to fail on stream creation + String garbageUri1 = "some. garbage"; + String garbageUri2 = "other garbage"; + + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(garbageUri1, garbageUri2), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); + + client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + fakeClock.forwardTime(20, TimeUnit.SECONDS); + verify(ldsWatcher, Mockito.timeout(5000).atLeastOnce()) + .onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOr = ldsUpdateCaptor.getValue(); + assertThat(statusOr.hasValue()).isFalse(); + assertThat(statusOr.getStatus().getDescription()).contains(garbageUri2); + verify(ldsWatcher, never()).onResourceChanged(argThat(StatusOr::hasValue)); + client.shutdown(); + } + + private Bootstrapper.ServerInfo getLrsServerInfo(String target) { + for (Map.Entry entry + : xdsClient.getServerLrsClientMap().entrySet()) { + if 
(entry.getKey().target().equals(target)) { + return entry.getKey(); + } + } + return null; + } + + @Test + public void used_then_mainServerRestart_fallbackServerUp() { + xdsClient = xdsClientPool.getObject(); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + + verify(ldsWatcher, timeout(5000)).onResourceChanged( + StatusOr.fromValue(LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER))); + + mainXdsServer.restartXdsServer(); + + assertThat(getLrsServerInfo("localhost:" + fallbackServer.getServer().getPort())).isNull(); + assertThat(getLrsServerInfo("localhost:" + mainXdsServer.getServer().getPort())).isNotNull(); + + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CLUSTER_NAME, cdsWatcher); + + verify(cdsWatcher, timeout(5000)).onResourceChanged(any()); + assertThat(getLrsServerInfo("localhost:" + fallbackServer.getServer().getPort())).isNull(); + } + + private Map defaultBootstrapOverride() { + return ImmutableMap.of( + "node", ImmutableMap.of( + "id", UUID.randomUUID().toString(), + "cluster", CLUSTER_NAME), + "xds_servers", ImmutableList.of( + ImmutableMap.of( + "server_uri", "localhost:" + mainXdsServer.getServer().getPort(), + "channel_creds", Collections.singletonList( + ImmutableMap.of("type", "insecure") + ), + "server_features", Collections.singletonList("xds_v3") + ), + ImmutableMap.of( + "server_uri", "localhost:" + fallbackServer.getServer().getPort(), + "channel_creds", Collections.singletonList( + ImmutableMap.of("type", "insecure") + ), + "server_features", Collections.singletonList("xds_v3") + ) + ), + "fallback-policy", "fallback" + ); + } +} diff --git a/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java b/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java index 0b8e89de721..da310871c25 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java @@ -17,12 +17,18 @@ package io.grpc.xds; import 
static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.grpc.MetricRecorder; +import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.internal.ObjectPool; import io.grpc.xds.Filter.NamedFilterConfig; import io.grpc.xds.XdsListenerResource.LdsUpdate; @@ -45,6 +51,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -73,12 +80,15 @@ public class XdsClientFederationTest { private ObjectPool xdsClientPool; private XdsClient xdsClient; private static final String DUMMY_TARGET = "dummy"; + private final MetricRecorder metricRecorder = new MetricRecorder() {}; @Before public void setUp() throws XdsInitializationException { SharedXdsClientPoolProvider clientPoolProvider = new SharedXdsClientPoolProvider(); - clientPoolProvider.setBootstrapOverride(defaultBootstrapOverride()); - xdsClientPool = clientPoolProvider.getOrCreate(DUMMY_TARGET); + xdsClientPool = clientPoolProvider.getOrCreate( + DUMMY_TARGET, + new GrpcBootstrapperImpl().bootstrap(defaultBootstrapOverride()), + metricRecorder); xdsClient = xdsClientPool.getObject(); } @@ -103,14 +113,19 @@ public void isolatedResourceDeletions() { xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "xdstp://server-one/envoy.config.listener.v3.Listener/test-server", mockDirectPathWatcher); - verify(mockWatcher, timeout(10000)).onChanged( - LdsUpdate.forApiListener( - HttpConnectionManager.forRdsName(0, "route-config.googleapis.com", ImmutableList.of( - new 
NamedFilterConfig("terminal-filter", RouterFilter.ROUTER_CONFIG))))); - verify(mockDirectPathWatcher, timeout(10000)).onChanged( - LdsUpdate.forApiListener( - HttpConnectionManager.forRdsName(0, "route-config.googleapis.com", ImmutableList.of( - new NamedFilterConfig("terminal-filter", RouterFilter.ROUTER_CONFIG))))); + @SuppressWarnings("unchecked") + ArgumentCaptor> captor = ArgumentCaptor.forClass(StatusOr.class); + LdsUpdate expectedUpdate = LdsUpdate.forApiListener( + HttpConnectionManager.forRdsName(0, "route-config.googleapis.com", ImmutableList.of( + new NamedFilterConfig("terminal-filter", RouterFilter.ROUTER_CONFIG)))); + + verify(mockWatcher, timeout(10000)).onResourceChanged(captor.capture()); + assertThat(captor.getValue().hasValue()).isTrue(); + assertThat(captor.getValue().getValue()).isEqualTo(expectedUpdate); + + verify(mockDirectPathWatcher, timeout(10000)).onResourceChanged(captor.capture()); + assertThat(captor.getValue().hasValue()).isTrue(); + assertThat(captor.getValue().getValue()).isEqualTo(expectedUpdate); // By setting the LDS config with a new server name we effectively make the old server to go // away as it is not in the configuration anymore. This change in one control plane (here the @@ -118,9 +133,13 @@ public void isolatedResourceDeletions() { // watcher of another control plane (here the DirectPath one). 
trafficdirector.setLdsConfig(ControlPlaneRule.buildServerListener(), ControlPlaneRule.buildClientListener("new-server")); - verify(mockWatcher, timeout(20000)).onResourceDoesNotExist("test-server"); - verify(mockDirectPathWatcher, times(0)).onResourceDoesNotExist( - "xdstp://server-one/envoy.config.listener.v3.Listener/test-server"); + verify(mockWatcher, timeout(20000)).onResourceChanged(argThat(statusOr -> { + return !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.NOT_FOUND + && statusOr.getStatus().getDescription().contains("test-server"); + })); + verify(mockDirectPathWatcher, times(1)).onResourceChanged(any()); + verify(mockDirectPathWatcher, never()).onAmbientError(any()); } /** @@ -152,7 +171,6 @@ public void lrsClientsStartedForLocalityStats() throws InterruptedException, Exe } } - /** * Assures that when an {@link XdsClient} is asked to add cluster locality stats it appropriately * starts {@link LoadReportClient}s to do that. diff --git a/xds/src/test/java/io/grpc/xds/XdsClientMetricReporterImplTest.java b/xds/src/test/java/io/grpc/xds/XdsClientMetricReporterImplTest.java new file mode 100644 index 00000000000..509a0025b7b --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/XdsClientMetricReporterImplTest.java @@ -0,0 +1,412 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import com.google.protobuf.Any; +import io.envoyproxy.envoy.config.listener.v3.Listener; +import io.grpc.MetricInstrument; +import io.grpc.MetricRecorder; +import io.grpc.MetricRecorder.BatchCallback; +import io.grpc.MetricRecorder.BatchRecorder; +import io.grpc.MetricSink; +import io.grpc.xds.XdsClientMetricReporterImpl.MetricReporterCallback; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceMetadata; +import io.grpc.xds.client.XdsClient.ServerConnectionCallback; +import io.grpc.xds.client.XdsResourceType; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; +import org.mockito.Captor; +import 
org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** + * Unit tests for {@link XdsClientMetricReporterImpl}. + */ +@RunWith(JUnit4.class) +public class XdsClientMetricReporterImplTest { + + private static final String target = "test-target"; + private static final String authority = "test-authority"; + private static final String server = "trafficdirector.googleapis.com"; + private static final String resourceTypeUrl = + "resourceTypeUrl.googleapis.com/envoy.config.cluster.v3.Cluster"; + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private XdsClient mockXdsClient; + @Captor + private ArgumentCaptor gaugeBatchCallbackCaptor; + private MetricRecorder mockMetricRecorder = mock(MetricRecorder.class, + delegatesTo(new MetricRecorderImpl())); + private BatchRecorder mockBatchRecorder = mock(BatchRecorder.class, + delegatesTo(new BatchRecorderImpl())); + + private XdsClientMetricReporterImpl reporter; + + @Before + public void setUp() { + reporter = new XdsClientMetricReporterImpl(mockMetricRecorder, target); + } + + @Test + public void reportResourceUpdates() { + reporter.reportResourceUpdates(10, 5, server, resourceTypeUrl); + verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.xds_client.resource_updates_valid"), eq((long) 10), + eq(Lists.newArrayList(target, server, resourceTypeUrl)), + eq(Lists.newArrayList())); + verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.xds_client.resource_updates_invalid"), + eq((long) 5), + eq(Lists.newArrayList(target, server, resourceTypeUrl)), + eq(Lists.newArrayList())); + } + + @Test + public void reportServerFailure() { + reporter.reportServerFailure(1, server); + verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.xds_client.server_failure"), eq((long) 1), + eq(Lists.newArrayList(target, server)), + eq(Lists.newArrayList())); + } + + @Test + public void 
setXdsClient_reportMetrics() throws Exception { + SettableFuture future = SettableFuture.create(); + future.set(null); + when(mockXdsClient.getSubscribedResourcesMetadataSnapshot()).thenReturn(Futures.immediateFuture( + ImmutableMap.of())); + when(mockXdsClient.reportServerConnections(any(ServerConnectionCallback.class))) + .thenReturn(future); + reporter.setXdsClient(mockXdsClient); + verify(mockMetricRecorder).registerBatchCallback(gaugeBatchCallbackCaptor.capture(), + eqMetricInstrumentName("grpc.xds_client.connected"), + eqMetricInstrumentName("grpc.xds_client.resources")); + gaugeBatchCallbackCaptor.getValue().accept(mockBatchRecorder); + verify(mockXdsClient).reportServerConnections(any(ServerConnectionCallback.class)); + } + + @Test + public void setXdsClient_reportCallbackMetrics_resourceCountsFails() { + TestlogHandler testLogHandler = new TestlogHandler(); + Logger logger = Logger.getLogger(XdsClientMetricReporterImpl.class.getName()); + logger.addHandler(testLogHandler); + + // For reporting resource counts connections, return a normally completed future + SettableFuture future = SettableFuture.create(); + future.set(null); + when(mockXdsClient.getSubscribedResourcesMetadataSnapshot()).thenReturn(Futures.immediateFuture( + ImmutableMap.of())); + + // Create a future that will throw an exception + SettableFuture serverConnectionsFeature = SettableFuture.create(); + serverConnectionsFeature.setException(new Exception("test")); + when(mockXdsClient.reportServerConnections(any())).thenReturn(serverConnectionsFeature); + + reporter.setXdsClient(mockXdsClient); + verify(mockMetricRecorder) + .registerBatchCallback(gaugeBatchCallbackCaptor.capture(), any(), any()); + gaugeBatchCallbackCaptor.getValue().accept(mockBatchRecorder); + // Verify that the xdsClient methods were called + // verify(mockXdsClient).reportResourceCounts(any()); + verify(mockXdsClient).reportServerConnections(any()); + + assertThat(testLogHandler.getLogs().size()).isEqualTo(1); + 
assertThat(testLogHandler.getLogs().get(0).getLevel()).isEqualTo(Level.WARNING); + assertThat(testLogHandler.getLogs().get(0).getMessage()).isEqualTo( + "Failed to report gauge metrics"); + logger.removeHandler(testLogHandler); + } + + @Test + public void metricGauges() { + SettableFuture future = SettableFuture.create(); + future.set(null); + when(mockXdsClient.getSubscribedResourcesMetadataSnapshot()) + .thenReturn(Futures.immediateFuture(ImmutableMap.of())); + when(mockXdsClient.reportServerConnections(any(ServerConnectionCallback.class))) + .thenReturn(future); + reporter.setXdsClient(mockXdsClient); + verify(mockMetricRecorder).registerBatchCallback(gaugeBatchCallbackCaptor.capture(), + eqMetricInstrumentName("grpc.xds_client.connected"), + eqMetricInstrumentName("grpc.xds_client.resources")); + BatchCallback gaugeBatchCallback = gaugeBatchCallbackCaptor.getValue(); + InOrder inOrder = inOrder(mockBatchRecorder); + // Trigger the internal call to reportCallbackMetrics() + gaugeBatchCallback.accept(mockBatchRecorder); + + ArgumentCaptor serverConnectionCallbackCaptor = + ArgumentCaptor.forClass(ServerConnectionCallback.class); + // verify(mockXdsClient).reportResourceCounts(resourceCallbackCaptor.capture()); + verify(mockXdsClient).reportServerConnections(serverConnectionCallbackCaptor.capture()); + + // Get the captured callback + MetricReporterCallback callback = (MetricReporterCallback) + serverConnectionCallbackCaptor.getValue(); + + // Verify that reportResourceCounts and reportServerConnections were called + // with the captured callback + callback.reportResourceCountGauge(10, "MrPotatoHead", + "acked", resourceTypeUrl); + inOrder.verify(mockBatchRecorder) + .recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), eq(10L), any(), + any()); + callback.reportServerConnectionGauge(true, "xdsServer"); + inOrder.verify(mockBatchRecorder) + .recordLongGauge(eqMetricInstrumentName("grpc.xds_client.connected"), + eq(1L), any(), any()); + + 
inOrder.verifyNoMoreInteractions(); + } + + @Test + public void metricReporterCallback() { + MetricReporterCallback callback = + new MetricReporterCallback(mockBatchRecorder, target); + + callback.reportServerConnectionGauge(true, server); + verify(mockBatchRecorder, times(1)).recordLongGauge( + eqMetricInstrumentName("grpc.xds_client.connected"), eq(1L), + eq(Lists.newArrayList(target, server)), + eq(Lists.newArrayList())); + + String cacheState = "requested"; + callback.reportResourceCountGauge(10, authority, cacheState, resourceTypeUrl); + verify(mockBatchRecorder, times(1)).recordLongGauge( + eqMetricInstrumentName("grpc.xds_client.resources"), eq(10L), + eq(Arrays.asList(target, authority, cacheState, resourceTypeUrl)), + eq(Collections.emptyList())); + } + + @Test + public void reportCallbackMetrics_computeAndReportResourceCounts() { + Map, Map> metadataByType = new HashMap<>(); + XdsResourceType listenerResource = XdsListenerResource.getInstance(); + XdsResourceType routeConfigResource = XdsRouteConfigureResource.getInstance(); + XdsResourceType clusterResource = XdsClusterResource.getInstance(); + + Any rawListener = Any.pack(Listener.newBuilder().setName("listener.googleapis.com").build()); + long nanosLastUpdate = 1577923199_606042047L; + + Map ldsResourceMetadataMap = new HashMap<>(); + ldsResourceMetadataMap.put("xdstp://authority1", + ResourceMetadata.newResourceMetadataRequested()); + ResourceMetadata ackedLdsResource = + ResourceMetadata.newResourceMetadataAcked(rawListener, "42", nanosLastUpdate); + ldsResourceMetadataMap.put("resource2", ackedLdsResource); + ldsResourceMetadataMap.put("resource3", + ResourceMetadata.newResourceMetadataAcked(rawListener, "43", nanosLastUpdate)); + ldsResourceMetadataMap.put("xdstp:/no_authority", + ResourceMetadata.newResourceMetadataNacked(ackedLdsResource, "44", + nanosLastUpdate, "nacked after previous ack", true)); + + Map rdsResourceMetadataMap = new HashMap<>(); + ResourceMetadata requestedRdsResourceMetadata 
= ResourceMetadata.newResourceMetadataRequested(); + rdsResourceMetadataMap.put("xdstp://authority5", + ResourceMetadata.newResourceMetadataNacked(requestedRdsResourceMetadata, "24", + nanosLastUpdate, "nacked after request", false)); + rdsResourceMetadataMap.put("xdstp://authority6", + ResourceMetadata.newResourceMetadataDoesNotExist()); + + Map cdsResourceMetadataMap = new HashMap<>(); + cdsResourceMetadataMap.put("xdstp://authority7", ResourceMetadata.newResourceMetadataUnknown()); + + metadataByType.put(listenerResource, ldsResourceMetadataMap); + metadataByType.put(routeConfigResource, rdsResourceMetadataMap); + metadataByType.put(clusterResource, cdsResourceMetadataMap); + + SettableFuture reportServerConnectionsCompleted = SettableFuture.create(); + reportServerConnectionsCompleted.set(null); + when(mockXdsClient.reportServerConnections(any(MetricReporterCallback.class))) + .thenReturn(reportServerConnectionsCompleted); + + ListenableFuture, Map>> + getResourceMetadataCompleted = Futures.immediateFuture(metadataByType); + when(mockXdsClient.getSubscribedResourcesMetadataSnapshot()) + .thenReturn(getResourceMetadataCompleted); + + reporter.reportCallbackMetrics(mockBatchRecorder, mockXdsClient); + + // LDS resource requested + verify(mockBatchRecorder).recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), + eq(1L), + eq(Arrays.asList(target, "authority1", "requested", listenerResource.typeUrl())), any()); + // LDS resources acked + // authority = #old, for non-xdstp resource names + verify(mockBatchRecorder).recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), + eq(2L), + eq(Arrays.asList(target, "#old", "acked", listenerResource.typeUrl())), any()); + // LDS resource nacked but cached + // "" for missing authority in the resource name + verify(mockBatchRecorder).recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), + eq(1L), + eq(Arrays.asList(target, "", "nacked_but_cached", listenerResource.typeUrl())), any()); 
+ + // RDS resource nacked + verify(mockBatchRecorder).recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), + eq(1L), + eq(Arrays.asList(target, "authority5", "nacked", routeConfigResource.typeUrl())), any()); + // RDS resource does not exist + verify(mockBatchRecorder).recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), + eq(1L), + eq(Arrays.asList(target, "authority6", "does_not_exist", routeConfigResource.typeUrl())), + any()); + + // CDS resource unknown + verify(mockBatchRecorder).recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), + eq(1L), + eq(Arrays.asList(target, "authority7", "unknown", clusterResource.typeUrl())), + any()); + verifyNoMoreInteractions(mockBatchRecorder); + } + + @Test + public void reportCallbackMetrics_computeAndReportResourceCounts_emptyResources() { + Map, Map> metadataByType = new HashMap<>(); + XdsResourceType listenerResource = XdsListenerResource.getInstance(); + metadataByType.put(listenerResource, Collections.emptyMap()); + + SettableFuture reportServerConnectionsCompleted = SettableFuture.create(); + reportServerConnectionsCompleted.set(null); + when(mockXdsClient.reportServerConnections(any(MetricReporterCallback.class))) + .thenReturn(reportServerConnectionsCompleted); + + ListenableFuture, Map>> + getResourceMetadataCompleted = Futures.immediateFuture(metadataByType); + when(mockXdsClient.getSubscribedResourcesMetadataSnapshot()) + .thenReturn(getResourceMetadataCompleted); + + reporter.reportCallbackMetrics(mockBatchRecorder, mockXdsClient); + + // Verify that reportResourceCountGauge is never called + verifyNoInteractions(mockBatchRecorder); + } + + @Test + public void reportCallbackMetrics_computeAndReportResourceCounts_nullMetadata() { + TestlogHandler testLogHandler = new TestlogHandler(); + Logger logger = Logger.getLogger(XdsClientMetricReporterImpl.class.getName()); + logger.addHandler(testLogHandler); + + SettableFuture reportServerConnectionsCompleted = 
SettableFuture.create(); + reportServerConnectionsCompleted.set(null); + when(mockXdsClient.reportServerConnections(any(MetricReporterCallback.class))) + .thenReturn(reportServerConnectionsCompleted); + + ListenableFuture, Map>> + getResourceMetadataCompleted = Futures.immediateFailedFuture( + new Exception("Error generating metadata snapshot")); + when(mockXdsClient.getSubscribedResourcesMetadataSnapshot()) + .thenReturn(getResourceMetadataCompleted); + + reporter.reportCallbackMetrics(mockBatchRecorder, mockXdsClient); + assertThat(testLogHandler.getLogs().size()).isEqualTo(1); + assertThat(testLogHandler.getLogs().get(0).getLevel()).isEqualTo(Level.WARNING); + assertThat(testLogHandler.getLogs().get(0).getMessage()).isEqualTo( + "Failed to report gauge metrics"); + logger.removeHandler(testLogHandler); + } + + @Test + public void close_closesGaugeRegistration() { + MetricSink.Registration mockRegistration = mock(MetricSink.Registration.class); + when(mockMetricRecorder.registerBatchCallback(any(MetricRecorder.BatchCallback.class), + eqMetricInstrumentName("grpc.xds_client.connected"), + eqMetricInstrumentName("grpc.xds_client.resources"))).thenReturn(mockRegistration); + + // Sets XdsClient and register the gauges + reporter.setXdsClient(mockXdsClient); + // Closes registered gauges + reporter.close(); + verify(mockRegistration, times(1)).close(); + } + + @SuppressWarnings("TypeParameterUnusedInFormals") + private T eqMetricInstrumentName(String name) { + return argThat(new ArgumentMatcher() { + @Override + public boolean matches(T instrument) { + return instrument.getName().equals(name); + } + }); + } + + static class MetricRecorderImpl implements MetricRecorder { + } + + static class BatchRecorderImpl implements BatchRecorder { + } + + static class TestlogHandler extends Handler { + List logs = new ArrayList<>(); + + @Override + public void publish(LogRecord record) { + logs.add(record); + } + + @Override + public void close() {} + + @Override + public void 
flush() {} + + public List getLogs() { + return logs; + } + } + +} diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java index f3f4d74eb2f..81186d0639c 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java @@ -32,9 +32,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.SettableFuture; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.inprocess.InProcessSocketAddress; import io.grpc.internal.TestUtils.NoopChannelLogger; import io.grpc.netty.GrpcHttp2ConnectionHandler; @@ -119,7 +121,8 @@ public void setUp() { when(mockBuilder.build()).thenReturn(mockServer); when(mockServer.isShutdown()).thenReturn(false); xdsServerWrapper = new XdsServerWrapper("0.0.0.0:" + PORT, mockBuilder, listener, - selectorManager, new FakeXdsClientPoolFactory(xdsClient), FilterRegistry.newRegistry()); + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, FilterRegistry.newRegistry()); } @Test @@ -165,11 +168,12 @@ public void run() { EnvoyServerProtoData.Listener tcpListener = EnvoyServerProtoData.Listener.create( "listener1", - "10.1.2.3", + "0.0.0.0:7000", ImmutableList.of(), - null); + null, + Protocol.TCP); LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(tcpListener); - xdsClient.ldsWatcher.onChanged(listenerUpdate); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromValue(listenerUpdate)); verify(listener, timeout(5000)).onServing(); start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); FilterChainSelector selector = selectorManager.getSelectorToUpdateSelector(); @@ -190,7 +194,8 @@ public void run() { } }); String ldsWatched = 
xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - xdsClient.ldsWatcher.onResourceDoesNotExist(ldsWatched); + Status status = Status.NOT_FOUND.withDescription("Resource not found: " + ldsWatched); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(status)); verify(listener, timeout(5000)).onNotServing(any()); try { start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); @@ -275,7 +280,8 @@ public void releaseOldSupplierOnNotFound_verifyClose() throws Exception { getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); callUpdateSslContext(returnedSupplier); - xdsClient.ldsWatcher.onResourceDoesNotExist("not-found Error"); + Status status = Status.NOT_FOUND.withDescription("not-found Error"); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(status)); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); } @@ -292,14 +298,14 @@ public void releaseOldSupplierOnTemporaryError_noClose() throws Exception { getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); callUpdateSslContext(returnedSupplier); - xdsClient.ldsWatcher.onError(Status.CANCELLED); + xdsClient.ldsWatcher.onAmbientError(Status.CANCELLED); verify(tlsContextManager, never()).releaseServerSslContextProvider(eq(sslContextProvider1)); } private void callUpdateSslContext(SslContextProviderSupplier sslContextProviderSupplier) { assertThat(sslContextProviderSupplier).isNotNull(); SslContextProvider.Callback callback = mock(SslContextProvider.Callback.class); - sslContextProviderSupplier.updateSslContext(callback); + sslContextProviderSupplier.updateSslContext(callback, false); } private void sendListenerUpdate( diff --git a/xds/src/test/java/io/grpc/xds/XdsDependencyManagerTest.java b/xds/src/test/java/io/grpc/xds/XdsDependencyManagerTest.java new 
file mode 100644 index 00000000000..7bae7000eaf --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/XdsDependencyManagerTest.java @@ -0,0 +1,1033 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.StatusMatcher.statusHasCode; +import static io.grpc.xds.XdsClusterResource.CdsUpdate.ClusterType.EDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; +import static io.grpc.xds.XdsTestUtils.CLUSTER_NAME; +import static io.grpc.xds.XdsTestUtils.ENDPOINT_HOSTNAME; +import static io.grpc.xds.XdsTestUtils.ENDPOINT_PORT; +import static io.grpc.xds.XdsTestUtils.RDS_NAME; +import static io.grpc.xds.XdsTestUtils.getEdsNameForCluster; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.google.common.collect.ImmutableMap; +import 
com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.protobuf.Message; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; +import io.envoyproxy.envoy.config.listener.v3.Listener; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; +import io.grpc.BindableService; +import io.grpc.ChannelLogger; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.NameResolverRegistry; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.StatusOrMatcher; +import io.grpc.SynchronizationContext; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.FakeClock; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.testing.FakeNameResolverProvider; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsConfig.XdsClusterConfig; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.client.Locality; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceMetadata; +import io.grpc.xds.client.XdsResourceType; +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Logger; +import 
org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; +import org.mockito.ArgumentMatchers; +import org.mockito.Captor; +import org.mockito.InOrder; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Unit tests for {@link XdsDependencyManager}. */ +@RunWith(JUnit4.class) +public class XdsDependencyManagerTest { + private static final Logger log = Logger.getLogger(XdsDependencyManagerTest.class.getName()); + public static final String CLUSTER_TYPE_NAME = XdsClusterResource.getInstance().typeName(); + public static final String ENDPOINT_TYPE_NAME = XdsEndpointResource.getInstance().typeName(); + + private final SynchronizationContext syncContext = + new SynchronizationContext((t, e) -> { + throw new AssertionError(e); + }); + private final FakeClock fakeClock = new FakeClock(); + + private XdsClient xdsClient = XdsTestUtils.createXdsClient( + Collections.singletonList("control-plane"), + serverInfo -> new GrpcXdsTransportFactory.GrpcXdsTransport( + InProcessChannelBuilder.forName(serverInfo.target()).directExecutor().build()), + fakeClock); + + private TestWatcher xdsConfigWatcher; + + private final String serverName = "the-service-name"; + private final Queue loadReportCalls = new ArrayDeque<>(); + private final AtomicBoolean adsEnded = new AtomicBoolean(true); + private final AtomicBoolean lrsEnded = new AtomicBoolean(true); + private final XdsTestControlPlaneService controlPlaneService = new XdsTestControlPlaneService(); + private final BindableService lrsService = + XdsTestUtils.createLrsService(lrsEnded, loadReportCalls); + private final NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + + @Rule + public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + @Rule + public final MockitoRule mocks 
= MockitoJUnit.rule(); + private TestWatcher testWatcher; + private XdsConfig defaultXdsConfig; // set in setUp() + + @Captor + private ArgumentCaptor> xdsUpdateCaptor; + private final NameResolver.Args nameResolverArgs = NameResolver.Args.newBuilder() + .setDefaultPort(8080) + .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR) + .setSynchronizationContext(syncContext) + .setServiceConfigParser(mock(NameResolver.ServiceConfigParser.class)) + .setChannelLogger(mock(ChannelLogger.class)) + .setScheduledExecutorService(fakeClock.getScheduledExecutorService()) + .setNameResolverRegistry(nameResolverRegistry) + .build(); + + private XdsDependencyManager xdsDependencyManager = new XdsDependencyManager( + xdsClient, syncContext, serverName, serverName, nameResolverArgs); + private boolean savedEnableLogicalDns; + + @Before + public void setUp() throws Exception { + cleanupRule.register(InProcessServerBuilder + .forName("control-plane") + .addService(controlPlaneService) + .addService(lrsService) + .directExecutor() + .build() + .start()); + + XdsTestUtils.setAdsConfig(controlPlaneService, serverName); + + testWatcher = new TestWatcher(); + xdsConfigWatcher = mock(TestWatcher.class, delegatesTo(testWatcher)); + defaultXdsConfig = XdsTestUtils.getDefaultXdsConfig(serverName); + + savedEnableLogicalDns = XdsDependencyManager.enableLogicalDns; + } + + @After + public void tearDown() throws InterruptedException { + if (xdsDependencyManager != null) { + xdsDependencyManager.shutdown(); + } + xdsClient.shutdown(); + + assertThat(adsEnded.get()).isTrue(); + assertThat(lrsEnded.get()).isTrue(); + assertThat(fakeClock.getPendingTasks()).isEmpty(); + + XdsDependencyManager.enableLogicalDns = savedEnableLogicalDns; + } + + @Test + public void verify_basic_config() { + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(StatusOr.fromValue(defaultXdsConfig)); + testWatcher.verifyStats(1, 0); + } + + @Test + public void verify_config_update() { + 
xdsDependencyManager.start(xdsConfigWatcher); + + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(StatusOr.fromValue(defaultXdsConfig)); + testWatcher.verifyStats(1, 0); + assertThat(testWatcher.lastConfig).isEqualTo(defaultXdsConfig); + + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS2", "CDS2", "EDS2", + ENDPOINT_HOSTNAME + "2", ENDPOINT_PORT + 2); + inOrder.verify(xdsConfigWatcher).onUpdate(ArgumentMatchers.notNull()); + testWatcher.verifyStats(2, 0); + assertThat(testWatcher.lastConfig).isNotEqualTo(defaultXdsConfig); + } + + @Test + public void verify_simple_aggregate() { + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(StatusOr.fromValue(defaultXdsConfig)); + + List childNames = Arrays.asList("clusterC", "clusterB", "clusterA"); + String rootName = "root_c"; + + RouteConfiguration routeConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, rootName); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + + XdsTestUtils.setAggregateCdsConfig(controlPlaneService, serverName, rootName, childNames); + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + + Map> lastConfigClusters = + testWatcher.lastConfig.getClusters(); + assertThat(lastConfigClusters).hasSize(childNames.size() + 1); + StatusOr rootC = lastConfigClusters.get(rootName); + assertThat(rootC.getValue().getChildren()).isInstanceOf(XdsClusterConfig.AggregateConfig.class); + XdsClusterConfig.AggregateConfig aggConfig = + (XdsClusterConfig.AggregateConfig) rootC.getValue().getChildren(); + assertThat(aggConfig.getLeafNames()).isEqualTo(childNames); + + for (String childName : childNames) { + assertThat(lastConfigClusters).containsKey(childName); + StatusOr childConfigOr = lastConfigClusters.get(childName); + CdsUpdate childResource = + 
childConfigOr.getValue().getClusterResource(); + assertThat(childResource.clusterType()).isEqualTo(EDS); + assertThat(childResource.edsServiceName()).isEqualTo(getEdsNameForCluster(childName)); + + StatusOr endpoint = getEndpoint(childConfigOr); + assertThat(endpoint.hasValue()).isTrue(); + assertThat(endpoint.getValue().clusterName).isEqualTo(getEdsNameForCluster(childName)); + } + } + + private static StatusOr getEndpoint(StatusOr childConfigOr) { + XdsClusterConfig.ClusterChild clusterChild = childConfigOr.getValue() + .getChildren(); + assertThat(clusterChild).isInstanceOf(XdsClusterConfig.EndpointConfig.class); + StatusOr endpoint = ((XdsClusterConfig.EndpointConfig) clusterChild).getEndpoint(); + assertThat(endpoint).isNotNull(); + return endpoint; + } + + @Test + public void testComplexRegisteredAggregate() throws IOException { + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + + // Do initialization + String rootName1 = "root_c"; + List childNames = Arrays.asList("clusterC", "clusterB", "clusterA"); + XdsTestUtils.addAggregateToExistingConfig(controlPlaneService, rootName1, childNames); + + String rootName2 = "root_2"; + List childNames2 = Arrays.asList("clusterA", "clusterX"); + XdsTestUtils.addAggregateToExistingConfig(controlPlaneService, rootName2, childNames2); + + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + + Closeable subscription1 = xdsDependencyManager.subscribeToCluster(rootName1); + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + + Closeable subscription2 = xdsDependencyManager.subscribeToCluster(rootName2); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + testWatcher.verifyStats(3, 0); + ImmutableSet.Builder builder = ImmutableSet.builder(); + Set expectedClusters = builder.add(rootName1).add(rootName2).add(CLUSTER_NAME) + .addAll(childNames).addAll(childNames2).build(); + assertThat(xdsUpdateCaptor.getValue().getValue().getClusters().keySet()) + 
.isEqualTo(expectedClusters); + + // Close 1 subscription shouldn't affect the other or RDS subscriptions + subscription1.close(); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + builder = ImmutableSet.builder(); + Set expectedClusters2 = + builder.add(rootName2).add(CLUSTER_NAME).addAll(childNames2).build(); + assertThat(xdsUpdateCaptor.getValue().getValue().getClusters().keySet()) + .isEqualTo(expectedClusters2); + + subscription2.close(); + inOrder.verify(xdsConfigWatcher).onUpdate(StatusOr.fromValue(defaultXdsConfig)); + } + + @Test + public void testDelayedSubscription() { + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(StatusOr.fromValue(defaultXdsConfig)); + + String rootName1 = "root_c"; + + Closeable subscription1 = xdsDependencyManager.subscribeToCluster(rootName1); + assertThat(subscription1).isNotNull(); + fakeClock.forwardTime(16, TimeUnit.SECONDS); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + Status status = xdsUpdateCaptor.getValue().getValue().getClusters().get(rootName1).getStatus(); + assertThat(status.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(status.getDescription()).contains(rootName1); + + List childNames = Arrays.asList("clusterC", "clusterB", "clusterA"); + XdsTestUtils.addAggregateToExistingConfig(controlPlaneService, rootName1, childNames); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + assertThat(xdsUpdateCaptor.getValue().getValue().getClusters().get(rootName1).hasValue()) + .isTrue(); + } + + @Test + public void testMissingCdsAndEds() { + // update config so that agg cluster references 2 existing & 1 non-existing cluster + List childNames = Arrays.asList("clusterC", "clusterB", "clusterA"); + Cluster cluster = XdsTestUtils.buildAggCluster(CLUSTER_NAME, childNames); + Map clusterMap = new HashMap<>(); + Map edsMap = new HashMap<>(); + + 
clusterMap.put(CLUSTER_NAME, cluster); + for (int i = 0; i < childNames.size() - 1; i++) { + String edsName = XdsTestUtils.EDS_NAME + i; + Cluster child = ControlPlaneRule.buildCluster(childNames.get(i), edsName); + clusterMap.put(childNames.get(i), child); + } + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + + // Update config so that one of the 2 "valid" clusters has an EDS resource, the other does not + // and there is an EDS that doesn't have matching clusters + ClusterLoadAssignment clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.1.1", ENDPOINT_HOSTNAME, ENDPOINT_PORT, XdsTestUtils.EDS_NAME + 0); + edsMap.put(XdsTestUtils.EDS_NAME + 0, clusterLoadAssignment); + clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.1.2", ENDPOINT_HOSTNAME, ENDPOINT_PORT, "garbageEds"); + edsMap.put("garbageEds", clusterLoadAssignment); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + xdsDependencyManager.start(xdsConfigWatcher); + + fakeClock.forwardTime(16, TimeUnit.SECONDS); + verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + + List> returnedClusters = new ArrayList<>(); + for (String childName : childNames) { + returnedClusters.add(xdsUpdateCaptor.getValue().getValue().getClusters().get(childName)); + } + + // Check that missing cluster reported Status and the other 2 are present + StatusOr missingCluster = returnedClusters.get(2); + assertThat(missingCluster.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(missingCluster.getStatus().getDescription()).contains(childNames.get(2)); + assertThat(returnedClusters.get(0).hasValue()).isTrue(); + assertThat(returnedClusters.get(1).hasValue()).isTrue(); + + // Check that missing EDS reported Status, the other one is present and the garbage EDS is not + assertThat(getEndpoint(returnedClusters.get(0)).hasValue()).isTrue(); + assertThat(getEndpoint(returnedClusters.get(1)).getStatus().getCode()) + 
.isEqualTo(Status.Code.UNAVAILABLE); + assertThat(getEndpoint(returnedClusters.get(1)).getStatus().getDescription()) + .contains(XdsTestUtils.EDS_NAME + 1); + + verify(xdsConfigWatcher, never()).onUpdate( + argThat(StatusOrMatcher.hasStatus(statusHasCode(Status.Code.UNAVAILABLE)))); + testWatcher.verifyStats(1, 0); + } + + @Test + public void testMissingLds() { + String ldsName = "badLdsName"; + xdsDependencyManager = new XdsDependencyManager(xdsClient, syncContext, + serverName, ldsName, nameResolverArgs); + xdsDependencyManager.start(xdsConfigWatcher); + + fakeClock.forwardTime(16, TimeUnit.SECONDS); + verify(xdsConfigWatcher).onUpdate( + argThat(StatusOrMatcher.hasStatus(statusHasCode(Status.Code.UNAVAILABLE) + .andDescriptionContains(ldsName)))); + + testWatcher.verifyStats(0, 1); + } + + @Test + public void testTcpListenerErrors() { + Listener serverListener = + ControlPlaneRule.buildServerListener().toBuilder().setName(serverName).build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, ImmutableMap.of(serverName, serverListener)); + xdsDependencyManager.start(xdsConfigWatcher); + + fakeClock.forwardTime(16, TimeUnit.SECONDS); + verify(xdsConfigWatcher).onUpdate( + argThat(StatusOrMatcher.hasStatus( + statusHasCode(Status.Code.UNAVAILABLE).andDescriptionContains("Not an API listener")))); + + testWatcher.verifyStats(0, 1); + } + + @Test + public void testMissingRds() { + String rdsName = "badRdsName"; + Listener clientListener = ControlPlaneRule.buildClientListener(serverName, rdsName); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, + ImmutableMap.of(serverName, clientListener)); + + xdsDependencyManager.start(xdsConfigWatcher); + + fakeClock.forwardTime(16, TimeUnit.SECONDS); + verify(xdsConfigWatcher).onUpdate( + argThat(StatusOrMatcher.hasStatus(statusHasCode(Status.Code.UNAVAILABLE) + .andDescriptionContains(rdsName)))); + + testWatcher.verifyStats(0, 1); + } + + @Test + public void testUpdateToMissingVirtualHost() { + RouteConfiguration 
routeConfig = XdsTestUtils.buildRouteConfiguration( + "wrong-virtual-host", XdsTestUtils.RDS_NAME, XdsTestUtils.CLUSTER_NAME); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + xdsDependencyManager.start(xdsConfigWatcher); + + // Update with a config that has a virtual host that doesn't match the server name + verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + assertThat(xdsUpdateCaptor.getValue().getStatus().getDescription()) + .contains("Failed to find virtual host matching hostname: " + serverName); + + testWatcher.verifyStats(0, 1); + } + + @Test + public void testCorruptLds() { + String ldsResourceName = + "xdstp://unknown.example.com/envoy.config.listener.v3.Listener/listener1"; + + xdsDependencyManager = new XdsDependencyManager(xdsClient, syncContext, + serverName, ldsResourceName, nameResolverArgs); + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate( + argThat(StatusOrMatcher.hasStatus( + statusHasCode(Status.Code.UNAVAILABLE).andDescriptionContains(ldsResourceName)))); + + fakeClock.forwardTime(16, TimeUnit.SECONDS); + testWatcher.verifyStats(0, 1); + } + + @Test + public void testChangeRdsName_fromLds() { + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(StatusOr.fromValue(defaultXdsConfig)); + + String newRdsName = "newRdsName1"; + + Listener clientListener = buildInlineClientListener(newRdsName, CLUSTER_NAME); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, + ImmutableMap.of(serverName, clientListener)); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + assertThat(xdsUpdateCaptor.getValue().getValue()).isNotEqualTo(defaultXdsConfig); + assertThat(xdsUpdateCaptor.getValue().getValue().getVirtualHost().name()).isEqualTo(newRdsName); + } + + @Test + public void testMultipleParentsInCdsTree() throws IOException { + /* + * 
Configure Xds server with the following cluster tree and point RDS to root: + 2 aggregates under root A & B + B has EDS Cluster B1 && shared agg AB1; A has agg A1 && shared agg AB1 + A1 has shared EDS Cluster A11 && shared agg AB1 + AB1 has shared EDS Clusters A11 && AB11 + + As an alternate visualization, parents are: + A -> root, B -> root, A1 -> A, AB1 -> A|B|A1, B1 -> B, A11 -> A1|AB1, AB11 -> AB1 + */ + Cluster rootCluster = + XdsTestUtils.buildAggCluster("root", Arrays.asList("clusterA", "clusterB")); + Cluster clusterA = + XdsTestUtils.buildAggCluster("clusterA", Arrays.asList("clusterA1", "clusterAB1")); + Cluster clusterB = + XdsTestUtils.buildAggCluster("clusterB", Arrays.asList("clusterB1", "clusterAB1")); + Cluster clusterA1 = + XdsTestUtils.buildAggCluster("clusterA1", Arrays.asList("clusterA11", "clusterAB1")); + Cluster clusterAB1 = + XdsTestUtils.buildAggCluster("clusterAB1", Arrays.asList("clusterA11", "clusterAB11")); + + Map clusterMap = new HashMap<>(); + Map edsMap = new HashMap<>(); + + clusterMap.put("root", rootCluster); + clusterMap.put("clusterA", clusterA); + clusterMap.put("clusterB", clusterB); + clusterMap.put("clusterA1", clusterA1); + clusterMap.put("clusterAB1", clusterAB1); + + XdsTestUtils.addEdsClusters(clusterMap, edsMap, "clusterA11", "clusterAB11", "clusterB1"); + RouteConfiguration routeConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, "root"); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + // Start the actual test + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig initialConfig = xdsUpdateCaptor.getValue().getValue(); + + // Make sure that adding subscriptions that 
rds points at doesn't change the config + Closeable rootSub = xdsDependencyManager.subscribeToCluster("root"); + assertThat(xdsDependencyManager.buildUpdate().getValue()).isEqualTo(initialConfig); + Closeable clusterAB11Sub = xdsDependencyManager.subscribeToCluster("clusterAB11"); + assertThat(xdsDependencyManager.buildUpdate().getValue()).isEqualTo(initialConfig); + + // Make sure that closing subscriptions that rds points at doesn't change the config + rootSub.close(); + assertThat(xdsDependencyManager.buildUpdate().getValue()).isEqualTo(initialConfig); + clusterAB11Sub.close(); + assertThat(xdsDependencyManager.buildUpdate().getValue()).isEqualTo(initialConfig); + + // Make an explicit root subscription and then change RDS to point to A11 + rootSub = xdsDependencyManager.subscribeToCluster("root"); + RouteConfiguration newRouteConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, "clusterA11"); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, newRouteConfig)); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + assertThat(xdsUpdateCaptor.getValue().getValue().getClusters()).hasSize(8); + + // Now that it is released, we should only have A11 + rootSub.close(); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + assertThat(xdsUpdateCaptor.getValue().getValue().getClusters().keySet()) + .containsExactly("clusterA11"); + } + + @Test + public void testCdsDeleteUnsubscribesChild() throws Exception { + RouteConfiguration routeConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, "clusterA"); + Map clusterMap = new HashMap<>(); + Map edsMap = new HashMap<>(); + XdsTestUtils.addEdsClusters(clusterMap, edsMap, "clusterA"); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + 
controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig config = xdsUpdateCaptor.getValue().getValue(); + assertThat(config.getClusters().get("clusterA").hasValue()).isTrue(); + Map, Map> watches = + xdsClient.getSubscribedResourcesMetadataSnapshot().get(); + assertThat(watches.get(XdsEndpointResource.getInstance()).keySet()) + .containsExactly("eds_clusterA"); + + // Delete cluster + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + config = xdsUpdateCaptor.getValue().getValue(); + assertThat(config.getClusters().get("clusterA").hasValue()).isFalse(); + watches = xdsClient.getSubscribedResourcesMetadataSnapshot().get(); + assertThat(watches).doesNotContainKey(XdsEndpointResource.getInstance()); + } + + @Test + public void testCdsCycleReclaimed() throws Exception { + RouteConfiguration routeConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, "clusterA"); + Map clusterMap = new HashMap<>(); + Map edsMap = new HashMap<>(); + clusterMap.put("clusterA", XdsTestUtils.buildAggCluster("clusterA", Arrays.asList("clusterB"))); + clusterMap.put("clusterB", XdsTestUtils.buildAggCluster("clusterB", Arrays.asList("clusterA"))); + XdsTestUtils.addEdsClusters(clusterMap, edsMap, "clusterC"); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + // The cycle is loaded and detected + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig config = 
xdsUpdateCaptor.getValue().getValue(); + assertThat(config.getClusters().get("clusterA").hasValue()).isFalse(); + assertThat(config.getClusters().get("clusterA").getStatus().getDescription()).contains("cycle"); + assertThat(config.getClusters().get("clusterB").hasValue()).isTrue(); + + // Orphan the cycle and it is discarded + routeConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, "clusterC"); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + Map, Map> watches = + xdsClient.getSubscribedResourcesMetadataSnapshot().get(); + assertThat(watches.get(XdsClusterResource.getInstance()).keySet()).containsExactly("clusterC"); + } + + @Test + public void testMultipleCdsReferToSameEds() { + // Create the maps and Update the config to have 2 clusters that refer to the same EDS resource + String edsName = "sharedEds"; + + Cluster rootCluster = + XdsTestUtils.buildAggCluster("root", Arrays.asList("clusterA", "clusterB")); + Cluster clusterA = ControlPlaneRule.buildCluster("clusterA", edsName); + Cluster clusterB = ControlPlaneRule.buildCluster("clusterB", edsName); + + Map clusterMap = new HashMap<>(); + clusterMap.put("root", rootCluster); + clusterMap.put("clusterA", clusterA); + clusterMap.put("clusterB", clusterB); + + Map edsMap = new HashMap<>(); + ClusterLoadAssignment clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.1.4", ENDPOINT_HOSTNAME, ENDPOINT_PORT, edsName); + edsMap.put(edsName, clusterLoadAssignment); + + RouteConfiguration routeConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, "root"); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + // Start the actual test + 
xdsDependencyManager.start(xdsConfigWatcher); + verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig initialConfig = xdsUpdateCaptor.getValue().getValue(); + assertThat(initialConfig.getClusters().keySet()) + .containsExactly("root", "clusterA", "clusterB"); + + EdsUpdate edsForA = getEndpoint(initialConfig.getClusters().get("clusterA")).getValue(); + assertThat(edsForA.clusterName).isEqualTo(edsName); + EdsUpdate edsForB = getEndpoint(initialConfig.getClusters().get("clusterB")).getValue(); + assertThat(edsForB.clusterName).isEqualTo(edsName); + assertThat(edsForA).isEqualTo(edsForB); + edsForA.localityLbEndpointsMap.values().forEach( + localityLbEndpoints -> assertThat(localityLbEndpoints.endpoints()).hasSize(1)); + } + + @Test + public void testChangeRdsName_FromLds_complexTree() { + xdsDependencyManager.start(xdsConfigWatcher); + + // Create the same tree as in testMultipleParentsInCdsTree + Cluster rootCluster = + XdsTestUtils.buildAggCluster("root", Arrays.asList("clusterA", "clusterB")); + Cluster clusterA = + XdsTestUtils.buildAggCluster("clusterA", Arrays.asList("clusterA1", "clusterAB1")); + Cluster clusterB = + XdsTestUtils.buildAggCluster("clusterB", Arrays.asList("clusterB1", "clusterAB1")); + Cluster clusterA1 = + XdsTestUtils.buildAggCluster("clusterA1", Arrays.asList("clusterA11", "clusterAB1")); + Cluster clusterAB1 = + XdsTestUtils.buildAggCluster("clusterAB1", Arrays.asList("clusterA11", "clusterAB11")); + + Map clusterMap = new HashMap<>(); + Map edsMap = new HashMap<>(); + + clusterMap.put("root", rootCluster); + clusterMap.put("clusterA", clusterA); + clusterMap.put("clusterB", clusterB); + clusterMap.put("clusterA1", clusterA1); + clusterMap.put("clusterAB1", clusterAB1); + + XdsTestUtils.addEdsClusters(clusterMap, edsMap, "clusterA11", "clusterAB11", "clusterB1"); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + InOrder inOrder = 
Mockito.inOrder(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher, atLeastOnce()).onUpdate(any()); + + // Do the test + String newRdsName = "newRdsName1"; + Listener clientListener = buildInlineClientListener(newRdsName, "root"); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, + ImmutableMap.of(serverName, clientListener)); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig config = xdsUpdateCaptor.getValue().getValue(); + assertThat(config.getVirtualHost().name()).isEqualTo(newRdsName); + assertThat(config.getClusters()).hasSize(8); + } + + @Test + public void testChangeAggCluster() { + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + + // Setup initial config A -> A1 -> (A11, A12) + Cluster rootCluster = + XdsTestUtils.buildAggCluster("root", Arrays.asList("clusterA")); + Cluster clusterA = + XdsTestUtils.buildAggCluster("clusterA", Arrays.asList("clusterA1")); + Cluster clusterA1 = + XdsTestUtils.buildAggCluster("clusterA1", Arrays.asList("clusterA11", "clusterA12")); + + Map clusterMap = new HashMap<>(); + Map edsMap = new HashMap<>(); + + clusterMap.put("root", rootCluster); + clusterMap.put("clusterA", clusterA); + clusterMap.put("clusterA1", clusterA1); + + XdsTestUtils.addEdsClusters(clusterMap, edsMap, "clusterA11", "clusterA12"); + Listener clientListener = buildInlineClientListener(RDS_NAME, "root"); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, + ImmutableMap.of(serverName, clientListener)); + + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + + // Update the cluster to A -> A2 -> (A21, A22) + Cluster clusterA2 = + XdsTestUtils.buildAggCluster("clusterA2", Arrays.asList("clusterA21", "clusterA22")); + clusterA = + XdsTestUtils.buildAggCluster("clusterA", 
Arrays.asList("clusterA2")); + clusterMap.clear(); + edsMap.clear(); + clusterMap.put("root", rootCluster); + clusterMap.put("clusterA", clusterA); + clusterMap.put("clusterA2", clusterA2); + XdsTestUtils.addEdsClusters(clusterMap, edsMap, "clusterA21", "clusterA22"); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + // Verify that the config is updated as expected + ClusterNameMatcher nameMatcher = new ClusterNameMatcher(Arrays.asList( + "root", "clusterA", "clusterA2", "clusterA21", "clusterA22")); + inOrder.verify(xdsConfigWatcher).onUpdate(argThat(nameMatcher)); + } + + @Test + public void testLogicalDns_success() { + XdsDependencyManager.enableLogicalDns = true; + FakeSocketAddress fakeAddress = new FakeSocketAddress(); + nameResolverRegistry.register(new FakeNameResolverProvider( + "dns:///dns.example.com:1111", fakeAddress)); + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER_NAME) + .setType(Cluster.DiscoveryType.LOGICAL_DNS) + .setLoadAssignment(ClusterLoadAssignment.newBuilder() + .addEndpoints(LocalityLbEndpoints.newBuilder() + .addLbEndpoints(LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("dns.example.com") + .setPortValue(1111))))))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, + ImmutableMap.of(CLUSTER_NAME, cluster)); + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig config = xdsUpdateCaptor.getValue().getValue(); + XdsClusterConfig.ClusterChild clusterChild = + config.getClusters().get(CLUSTER_NAME).getValue().getChildren(); + assertThat(clusterChild).isInstanceOf(XdsClusterConfig.EndpointConfig.class); + StatusOr endpointOr = ((XdsClusterConfig.EndpointConfig) clusterChild).getEndpoint(); + assertThat(endpointOr.getStatus()).isEqualTo(Status.OK); + 
assertThat(endpointOr.getValue()).isEqualTo(new EdsUpdate( + "fakeEds_logicalDns", + ImmutableMap.of( + Locality.create("", "", ""), + Endpoints.LocalityLbEndpoints.create( + Arrays.asList(Endpoints.LbEndpoint.create( + new EquivalentAddressGroup(fakeAddress), + 1, true, "dns.example.com:1111", ImmutableMap.of())), + 1, 0, ImmutableMap.of())), + Arrays.asList())); + } + + @Test + public void testLogicalDns_noDnsNr() { + XdsDependencyManager.enableLogicalDns = true; + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER_NAME) + .setType(Cluster.DiscoveryType.LOGICAL_DNS) + .setLoadAssignment(ClusterLoadAssignment.newBuilder() + .addEndpoints(LocalityLbEndpoints.newBuilder() + .addLbEndpoints(LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("dns.example.com") + .setPortValue(1111))))))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, + ImmutableMap.of(CLUSTER_NAME, cluster)); + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig config = xdsUpdateCaptor.getValue().getValue(); + XdsClusterConfig.ClusterChild clusterChild = + config.getClusters().get(CLUSTER_NAME).getValue().getChildren(); + assertThat(clusterChild).isInstanceOf(XdsClusterConfig.EndpointConfig.class); + StatusOr endpointOr = ((XdsClusterConfig.EndpointConfig) clusterChild).getEndpoint(); + assertThat(endpointOr.getStatus().getCode()).isEqualTo(Status.Code.INTERNAL); + assertThat(endpointOr.getStatus().getDescription()) + .isEqualTo("Could not find dns name resolver"); + } + + @Test + public void testCdsError() throws IOException { + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_CDS, ImmutableMap.of(XdsTestUtils.CLUSTER_NAME, + Cluster.newBuilder().setName(XdsTestUtils.CLUSTER_NAME).build())); + xdsDependencyManager.start(xdsConfigWatcher); + + 
verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + Status status = xdsUpdateCaptor.getValue().getValue() + .getClusters().get(CLUSTER_NAME).getStatus(); + assertThat(status.getDescription()).contains(XdsTestUtils.CLUSTER_NAME); + } + + @Test + public void ldsUpdateAfterShutdown() { + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(any()); + + @SuppressWarnings("unchecked") + XdsClient.ResourceWatcher resourceWatcher = + mock(XdsClient.ResourceWatcher.class); + xdsClient.watchXdsResource( + XdsListenerResource.getInstance(), + serverName, + resourceWatcher, + MoreExecutors.directExecutor()); + verify(resourceWatcher).onResourceChanged(argThat(StatusOr::hasValue)); + + syncContext.execute(() -> { + // Shutdown before any updates. This will unsubscribe from XdsClient, but only after this + // Runnable returns + xdsDependencyManager.shutdown(); + + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS2", "CDS", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + verify(resourceWatcher, times(2)).onResourceChanged(argThat(StatusOr::hasValue)); + xdsClient.cancelXdsResourceWatch( + XdsListenerResource.getInstance(), serverName, resourceWatcher); + }); + } + + @Test + public void rdsUpdateAfterShutdown() { + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(any()); + + @SuppressWarnings("unchecked") + XdsClient.ResourceWatcher resourceWatcher = + mock(XdsClient.ResourceWatcher.class); + xdsClient.watchXdsResource( + XdsRouteConfigureResource.getInstance(), + "RDS", + resourceWatcher, + MoreExecutors.directExecutor()); + verify(resourceWatcher).onResourceChanged(argThat(StatusOr::hasValue)); + + syncContext.execute(() -> { + // Shutdown before any 
updates. This will unsubscribe from XdsClient, but only after this + // Runnable returns + xdsDependencyManager.shutdown(); + + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS2", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + verify(resourceWatcher, times(2)).onResourceChanged(argThat(StatusOr::hasValue)); + xdsClient.cancelXdsResourceWatch( + XdsRouteConfigureResource.getInstance(), serverName, resourceWatcher); + }); + } + + @Test + public void cdsUpdateAfterShutdown() { + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(any()); + + @SuppressWarnings("unchecked") + XdsClient.ResourceWatcher resourceWatcher = + mock(XdsClient.ResourceWatcher.class); + xdsClient.watchXdsResource( + XdsClusterResource.getInstance(), + "CDS", + resourceWatcher, + MoreExecutors.directExecutor()); + verify(resourceWatcher).onResourceChanged(argThat(StatusOr::hasValue)); + + syncContext.execute(() -> { + // Shutdown before any updates. 
This will unsubscribe from XdsClient, but only after this + // Runnable returns + xdsDependencyManager.shutdown(); + + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS2", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + verify(resourceWatcher, times(2)).onResourceChanged(argThat(StatusOr::hasValue)); + xdsClient.cancelXdsResourceWatch( + XdsClusterResource.getInstance(), serverName, resourceWatcher); + }); + } + + @Test + public void edsUpdateAfterShutdown() { + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(any()); + + @SuppressWarnings("unchecked") + XdsClient.ResourceWatcher resourceWatcher = + mock(XdsClient.ResourceWatcher.class); + xdsClient.watchXdsResource( + XdsEndpointResource.getInstance(), + "EDS", + resourceWatcher, + MoreExecutors.directExecutor()); + verify(resourceWatcher).onResourceChanged(argThat(StatusOr::hasValue)); + + syncContext.execute(() -> { + // Shutdown before any updates. 
This will unsubscribe from XdsClient, but only after this + // Runnable returns + xdsDependencyManager.shutdown(); + + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS", + ENDPOINT_HOSTNAME + "2", ENDPOINT_PORT); + verify(resourceWatcher, times(2)).onResourceChanged(argThat(StatusOr::hasValue)); + xdsClient.cancelXdsResourceWatch( + XdsEndpointResource.getInstance(), serverName, resourceWatcher); + }); + } + + @Test + public void subscribeToClusterAfterShutdown() throws Exception { + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + xdsDependencyManager.shutdown(); + + Closeable subscription = xdsDependencyManager.subscribeToCluster("CDS"); + inOrder.verify(xdsConfigWatcher, never()).onUpdate(any()); + subscription.close(); + } + + private Listener buildInlineClientListener(String rdsName, String clusterName) { + return XdsTestUtils.buildInlineClientListener(rdsName, clusterName, serverName); + } + + private static class TestWatcher implements XdsDependencyManager.XdsConfigWatcher { + XdsConfig lastConfig; + int numUpdates = 0; + int numError = 0; + + @Override + public void onUpdate(StatusOr update) { + log.fine("Config update: " + update); + if (update.hasValue()) { + lastConfig = update.getValue(); + numUpdates++; + } else { + numError++; + } + } + + private List getStats() { + return Arrays.asList(numUpdates, numError); + } + + private void verifyStats(int updt, int err) { + assertThat(getStats()).isEqualTo(Arrays.asList(updt, err)); + } + } + + static class ClusterNameMatcher implements ArgumentMatcher> { + private final List expectedNames; + + ClusterNameMatcher(List expectedNames) { + this.expectedNames = expectedNames; + } + + @Override + public boolean matches(StatusOr update) { + if (!update.hasValue()) { 
+ return false; + } + XdsConfig xdsConfig = update.getValue(); + if (xdsConfig == null || xdsConfig.getClusters() == null) { + return false; + } + return xdsConfig.getClusters().size() == expectedNames.size() + && xdsConfig.getClusters().keySet().containsAll(expectedNames); + } + } + + private static class FakeSocketAddress extends java.net.SocketAddress {} +} diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java index a216c3de028..8998a2bae99 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java @@ -23,23 +23,28 @@ import com.google.common.collect.ImmutableMap; import io.grpc.ChannelLogger; import io.grpc.InternalServiceProviders; +import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; import io.grpc.SynchronizationContext; +import io.grpc.Uri; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; import java.net.URI; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; /** Unit tests for {@link XdsNameResolverProvider}. 
*/ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class XdsNameResolverProviderTest { private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @@ -57,10 +62,18 @@ public void uncaughtException(Thread t, Throwable e) { .setServiceConfigParser(mock(ServiceConfigParser.class)) .setScheduledExecutorService(fakeClock.getScheduledExecutorService()) .setChannelLogger(mock(ChannelLogger.class)) + .setMetricRecorder(mock(MetricRecorder.class)) .build(); private XdsNameResolverProvider provider = new XdsNameResolverProvider(); + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + @Test public void provided() { for (NameResolverProvider current @@ -79,48 +92,46 @@ public void isAvailable() { } @Test - public void newNameResolver() { - assertThat( - provider.newNameResolver(URI.create("xds://1.1.1.1/foo.googleapis.com"), args)) + public void newNameResolver_returnsExpectedType() { + assertThat(newNameResolver(provider, "xds://1.1.1.1/foo.googleapis.com", args)) .isInstanceOf(XdsNameResolver.class); - assertThat( - provider.newNameResolver(URI.create("xds:///foo.googleapis.com"), args)) + assertThat(newNameResolver(provider, "xds:///foo.googleapis.com", args)) .isInstanceOf(XdsNameResolver.class); - assertThat( - provider.newNameResolver(URI.create("notxds://1.1.1.1/foo.googleapis.com"), - args)) - .isNull(); + } + + @Test + public void newNameResolver_matchesExpectedScheme() { + assertThat(newNameResolver(provider, "notxds://1.1.1.1/foo.googleapis.com", args)).isNull(); } @Test public void validName_withAuthority() { - XdsNameResolver resolver = - provider.newNameResolver( - URI.create("xds://trafficdirector.google.com/foo.googleapis.com"), args); + NameResolver resolver = + newNameResolver(provider, "xds://trafficdirector.google.com/foo.googleapis.com", 
args); assertThat(resolver).isNotNull(); assertThat(resolver.getServiceAuthority()).isEqualTo("foo.googleapis.com"); } @Test public void validName_noAuthority() { - XdsNameResolver resolver = - provider.newNameResolver(URI.create("xds:///foo.googleapis.com"), args); + NameResolver resolver = newNameResolver(provider, "xds:///foo.googleapis.com", args); assertThat(resolver).isNotNull(); assertThat(resolver.getServiceAuthority()).isEqualTo("foo.googleapis.com"); } @Test public void validName_urlExtractedAuthorityInvalidWithoutEncoding() { - XdsNameResolver resolver = - provider.newNameResolver(URI.create("xds:///1234/path/foo.googleapis.com:8080"), args); + NameResolver resolver = + newNameResolver(provider, "xds:///1234/path/foo.googleapis.com:8080", args); assertThat(resolver).isNotNull(); assertThat(resolver.getServiceAuthority()).isEqualTo("1234%2Fpath%2Ffoo.googleapis.com:8080"); } @Test public void validName_urlwithTargetAuthorityAndExtractedAuthorityInvalidWithoutEncoding() { - XdsNameResolver resolver = provider.newNameResolver(URI.create( - "xds://trafficdirector.google.com/1234/path/foo.googleapis.com:8080"), args); + NameResolver resolver = + newNameResolver( + provider, "xds://trafficdirector.google.com/1234/path/foo.googleapis.com:8080", args); assertThat(resolver).isNotNull(); assertThat(resolver.getServiceAuthority()).isEqualTo("1234%2Fpath%2Ffoo.googleapis.com:8080"); } @@ -133,18 +144,14 @@ public void newProvider_multipleScheme() { XdsNameResolverProvider provider1 = XdsNameResolverProvider.createForTest("new-xds-scheme", new HashMap()); registry.register(provider1); - assertThat(registry.asFactory() - .newNameResolver(URI.create("xds:///localhost"), args)).isNotNull(); - assertThat(registry.asFactory() - .newNameResolver(URI.create("new-xds-scheme:///localhost"), args)).isNotNull(); - assertThat(registry.asFactory() - .newNameResolver(URI.create("no-scheme:///localhost"), args)).isNotNull(); + assertThat(newNameResolver(registry.asFactory(), 
"xds:///localhost", args)).isNotNull(); + assertThat(newNameResolver(registry.asFactory(), "new-xds-scheme:///localhost", args)) + .isNotNull(); + assertThat(newNameResolver(registry.asFactory(), "no-scheme:///localhost", args)).isNotNull(); registry.deregister(provider1); - assertThat(registry.asFactory() - .newNameResolver(URI.create("new-xds-scheme:///localhost"), args)).isNull(); + assertThat(newNameResolver(registry.asFactory(), "new-xds-scheme:///localhost", args)).isNull(); registry.deregister(provider0); - assertThat(registry.asFactory() - .newNameResolver(URI.create("xds:///localhost"), args)).isNotNull(); + assertThat(newNameResolver(registry.asFactory(), "xds:///localhost", args)).isNotNull(); } @Test @@ -174,4 +181,11 @@ public void newProvider_overrideBootstrap() { resolver.shutdown(); registry.deregister(provider); } + + private NameResolver newNameResolver( + NameResolver.Factory factory, String uriString, NameResolver.Args args) { + return enableRfc3986UrisParam + ? factory.newNameResolver(Uri.create(uriString), args) + : factory.newNameResolver(URI.create(uriString), args); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index 24c2a43b83a..e78f97635ed 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -17,15 +17,19 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; import static io.grpc.xds.FaultFilter.HEADER_ABORT_GRPC_STATUS_KEY; import static io.grpc.xds.FaultFilter.HEADER_ABORT_HTTP_STATUS_KEY; import static io.grpc.xds.FaultFilter.HEADER_ABORT_PERCENTAGE_KEY; import static io.grpc.xds.FaultFilter.HEADER_DELAY_KEY; import static io.grpc.xds.FaultFilter.HEADER_DELAY_PERCENTAGE_KEY; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import 
static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; @@ -43,6 +47,7 @@ import com.google.re2j.Pattern; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.ChannelLogger; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; @@ -55,6 +60,7 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.MethodType; +import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; import io.grpc.NameResolver.ResolutionResult; @@ -63,9 +69,11 @@ import io.grpc.NoopClientCall.NoopClientCallListener; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.internal.AutoConfiguredLoadBalancerFactory; import io.grpc.internal.FakeClock; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonParser; import io.grpc.internal.JsonUtil; import io.grpc.internal.ObjectPool; @@ -85,6 +93,8 @@ import io.grpc.xds.VirtualHost.Route.RouteAction.RetryPolicy; import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.client.Bootstrapper.AuthorityInfo; @@ -92,14 +102,12 @@ import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.EnvoyProtoData.Node; import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.client.XdsResourceType; import 
java.io.IOException; -import java.net.URI; -import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -129,6 +137,17 @@ public class XdsNameResolverTest { private static final String RDS_RESOURCE_NAME = "route-configuration.googleapis.com"; private static final String FAULT_FILTER_INSTANCE_NAME = "envoy.fault"; private static final String ROUTER_FILTER_INSTANCE_NAME = "envoy.router"; + private static final FaultFilter.Provider FAULT_FILTER_PROVIDER = new FaultFilter.Provider(); + private static final RouterFilter.Provider ROUTER_FILTER_PROVIDER = new RouterFilter.Provider(); + + // Readability: makes it simpler to distinguish resource parameters. + private static final ImmutableMap NO_FILTER_OVERRIDES = ImmutableMap.of(); + private static final ImmutableList NO_HASH_POLICIES = ImmutableList.of(); + + // Stateful instance filter names. 
+ private static final String STATEFUL_1 = "test.stateful.filter.1"; + private static final String STATEFUL_2 = "test.stateful.filter.2"; + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); private final SynchronizationContext syncContext = new SynchronizationContext( @@ -152,6 +171,16 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { private final CallInfo call1 = new CallInfo("HelloService", "hi"); private final CallInfo call2 = new CallInfo("GreetService", "bye"); private final TestChannel channel = new TestChannel(); + private final MetricRecorder metricRecorder = new MetricRecorder() {}; + private final Map rawBootstrap = ImmutableMap.of( + "xds_servers", ImmutableList.of( + ImmutableMap.of( + "server_uri", "td.googleapis.com", + "channel_creds", ImmutableList.of( + ImmutableMap.of( + "type", "insecure"))) + )); + private BootstrapInfo bootstrapInfo = BootstrapInfo.builder() .servers(ImmutableList.of(ServerInfo.create( "td.googleapis.com", InsecureChannelCredentials.create()))) @@ -170,77 +199,77 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { private XdsNameResolver resolver; private TestCall testCall; private boolean originalEnableTimeout; - private URI targetUri; + private String targetUri = AUTHORITY; + private final NameResolver.Args nameResolverArgs = NameResolver.Args.newBuilder() + .setDefaultPort(8080) + .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR) + .setSynchronizationContext(syncContext) + .setServiceConfigParser(mock(NameResolver.ServiceConfigParser.class)) + .setChannelLogger(mock(ChannelLogger.class)) + .setScheduledExecutorService(fakeClock.getScheduledExecutorService()) + .build(); + @Before public void setUp() { - try { - targetUri = new URI(AUTHORITY); - } catch (URISyntaxException e) { - targetUri = null; - } + lenient().doReturn(Status.OK).when(mockListener).onResult2(any()); originalEnableTimeout = XdsNameResolver.enableTimeout; XdsNameResolver.enableTimeout = true; + + // Replace 
FaultFilter.Provider with the one returning FaultFilter injected with mockRandom. + Filter.Provider faultFilterProvider = + mock(Filter.Provider.class, delegatesTo(FAULT_FILTER_PROVIDER)); + // Lenient: suppress [MockitoHint] Unused warning, only used in resolved_fault* tests. + lenient() + .doReturn(new FaultFilter(mockRandom, new AtomicLong())) + .when(faultFilterProvider).newInstance(any(String.class)); + FilterRegistry filterRegistry = FilterRegistry.newRegistry().register( - new FaultFilter(mockRandom, new AtomicLong()), - RouterFilter.INSTANCE); + ROUTER_FILTER_PROVIDER, + faultFilterProvider); + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, filterRegistry, null); + xdsClientPoolFactory, mockRandom, filterRegistry, rawBootstrap, metricRecorder, + nameResolverArgs); } @After public void tearDown() { XdsNameResolver.enableTimeout = originalEnableTimeout; + if (resolver == null) { + // Allow tests to test shutdown. 
+ return; + } FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); resolver.shutdown(); if (xdsClient != null) { assertThat(xdsClient.ldsWatcher).isNull(); - assertThat(xdsClient.rdsWatcher).isNull(); + assertThat(xdsClient.rdsWatchers).isEmpty(); } } @Test public void resolving_failToCreateXdsClientPool() { - XdsClientPoolFactory xdsClientPoolFactory = new XdsClientPoolFactory() { - @Override - public void setBootstrapOverride(Map bootstrap) { - } - - @Override - @Nullable - public ObjectPool get(String target) { - throw new UnsupportedOperationException("Should not be called"); - } - - @Override - public ObjectPool getOrCreate(String target) throws XdsInitializationException { - throw new XdsInitializationException("Fail to read bootstrap file"); - } - - @Override - public List getTargets() { - return null; - } - }; - resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), + Collections.emptyMap(), metricRecorder, nameResolverArgs); resolver.start(mockListener); verify(mockListener).onError(errorCaptor.capture()); Status error = errorCaptor.getValue(); assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(error.getDescription()).isEqualTo("Failed to initialize xDS"); - assertThat(error.getCause()).hasMessageThat().isEqualTo("Fail to read bootstrap file"); + assertThat(error.getCause()).hasMessageThat().contains("Invalid bootstrap"); } @Test public void resolving_withTargetAuthorityNotFound() { resolver = new XdsNameResolver(targetUri, "notfound.google.com", AUTHORITY, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); 
resolver.start(mockListener); verify(mockListener).onError(errorCaptor.capture()); Status error = errorCaptor.getValue(); @@ -262,7 +291,42 @@ public void resolving_noTargetAuthority_templateWithoutXdstp() { resolver = new XdsNameResolver( targetUri, null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, - mockRandom, FilterRegistry.getDefaultRegistry(), null); + mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, metricRecorder, + nameResolverArgs); + resolver.start(mockListener); + verify(mockListener, never()).onError(any(Status.class)); + } + + @Test + public void resolving_emptyTargetAuthority_templateWithXdstp() { + bootstrapInfo = + BootstrapInfo.builder() + .servers( + ImmutableList.of( + ServerInfo.create("td.googleapis.com", InsecureChannelCredentials.create()))) + .node(Node.newBuilder().build()) + .clientDefaultListenerResourceNameTemplate( + "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/%s?id=1") + .build(); + String serviceAuthority = "[::FFFF:129.144.52.38]:80"; + expectedLdsResourceName = + "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/" + + "%5B::FFFF:129.144.52.38%5D:80?id=1"; + resolver = + new XdsNameResolver( + "xds:///foo.googleapis.com", + "", + serviceAuthority, + null, + serviceConfigParser, + syncContext, + scheduler, + xdsClientPoolFactory, + mockRandom, + FilterRegistry.getDefaultRegistry(), + rawBootstrap, + metricRecorder, + nameResolverArgs); resolver.start(mockListener); verify(mockListener, never()).onError(any(Status.class)); } @@ -282,7 +346,8 @@ public void resolving_noTargetAuthority_templateWithXdstp() { + "%5B::FFFF:129.144.52.38%5D:80?id=1"; resolver = new XdsNameResolver( targetUri, null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, 
nameResolverArgs); resolver.start(mockListener); verify(mockListener, never()).onError(any(Status.class)); } @@ -302,7 +367,8 @@ public void resolving_noTargetAuthority_xdstpWithMultipleSlashes() { + "path/to/service?id=1"; resolver = new XdsNameResolver( targetUri, null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); // The Service Authority must be URL encoded, but unlike the LDS resource name. @@ -318,25 +384,27 @@ public void resolving_targetAuthorityInAuthoritiesMap() { String serviceAuthority = "[::FFFF:129.144.52.38]:80"; bootstrapInfo = BootstrapInfo.builder() .servers(ImmutableList.of(ServerInfo.create( - "td.googleapis.com", InsecureChannelCredentials.create(), true))) + "td.googleapis.com", InsecureChannelCredentials.create(), true, true, false, false))) .node(Node.newBuilder().build()) .authorities( ImmutableMap.of(targetAuthority, AuthorityInfo.create( "xdstp://" + targetAuthority + "/envoy.config.listener.v3.Listener/%s?foo=1&bar=2", ImmutableList.of(ServerInfo.create( - "td.googleapis.com", InsecureChannelCredentials.create(), true))))) + "td.googleapis.com", InsecureChannelCredentials.create(), + true, true, false, false))))) .build(); expectedLdsResourceName = "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/" + "%5B::FFFF:129.144.52.38%5D:80?bar=2&foo=1"; // query param canonified resolver = new XdsNameResolver(targetUri, "xds.authority.com", serviceAuthority, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); resolver.start(mockListener); verify(mockListener, never()).onError(any(Status.class)); } @Test - public 
void resolving_ldsResourceNotFound() { + public void resolving_ldsResourceNotFound() { // hi resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsResourceNotFound(); @@ -348,11 +416,11 @@ public void resolving_ldsResourceNotFound() { public void resolving_ldsResourceUpdateRdsName() { Route route1 = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); Route route2 = Route.forAction(RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), null), + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), null, false), ImmutableMap.of()); bootstrapInfo = BootstrapInfo.builder() .servers(ImmutableList.of(ServerInfo.create( @@ -362,7 +430,8 @@ public void resolving_ldsResourceUpdateRdsName() { .build(); resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); // use different ldsResourceName and service authority. The virtualhost lookup should use // service authority. 
expectedLdsResourceName = "test-" + expectedLdsResourceName; @@ -370,32 +439,36 @@ public void resolving_ldsResourceUpdateRdsName() { resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); - assertThat(xdsClient.rdsResource).isEqualTo(RDS_RESOURCE_NAME); + assertThat(xdsClient.rdsWatchers.keySet()).containsExactly(RDS_RESOURCE_NAME); VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList(AUTHORITY), Collections.singletonList(route1), ImmutableMap.of()); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster1), (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolutionResult.class); String alternativeRdsResource = "route-configuration-alter.googleapis.com"; xdsClient.deliverLdsUpdateForRdsName(alternativeRdsResource); - assertThat(xdsClient.rdsResource).isEqualTo(alternativeRdsResource); + assertThat(xdsClient.rdsWatchers.keySet()).contains(alternativeRdsResource); virtualHost = VirtualHost.create("virtualhost-alter", Collections.singletonList(AUTHORITY), Collections.singletonList(route2), ImmutableMap.of()); xdsClient.deliverRdsUpdate(alternativeRdsResource, Collections.singletonList(virtualHost)); + createAndDeliverClusterUpdates(xdsClient, cluster2); + assertThat(xdsClient.rdsWatchers.keySet()).containsExactly(alternativeRdsResource); // Two new service config updates triggered: // - with load balancing config being able to select cluster1 and cluster2 // - with load balancing config 
being able to select cluster2 only - verify(mockListener, times(2)).onResult(resultCaptor.capture()); + verify(mockListener, times(3)).onResult2(resultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Arrays.asList(cluster1, cluster2), (Map) resultCaptor.getAllValues().get(0).getServiceConfig().getConfig()); @@ -418,35 +491,39 @@ public void resolving_rdsResourceNotFound() { public void resolving_ldsResourceRevokedAndAddedBack() { Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); - assertThat(xdsClient.rdsResource).isEqualTo(RDS_RESOURCE_NAME); + assertThat(xdsClient.rdsWatchers.keySet()).containsExactly(RDS_RESOURCE_NAME); VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList(AUTHORITY), Collections.singletonList(route), ImmutableMap.of()); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster1), (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); xdsClient.deliverLdsResourceNotFound(); // revoke LDS resource - assertThat(xdsClient.rdsResource).isNull(); // stop subscribing to stale RDS resource + assertThat(xdsClient.rdsWatchers.keySet()).isEmpty(); // stop subscribing to stale RDS resource 
assertEmptyResolutionResult(expectedLdsResourceName); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); // No name resolution result until new RDS resource update is received. Do not use stale config verifyNoInteractions(mockListener); - assertThat(xdsClient.rdsResource).isEqualTo(RDS_RESOURCE_NAME); + assertThat(xdsClient.rdsWatchers.keySet()).containsExactly(RDS_RESOURCE_NAME); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster1), (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); @@ -457,31 +534,35 @@ public void resolving_ldsResourceRevokedAndAddedBack() { public void resolving_rdsResourceRevokedAndAddedBack() { Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); - assertThat(xdsClient.rdsResource).isEqualTo(RDS_RESOURCE_NAME); + assertThat(xdsClient.rdsWatchers.keySet()).containsExactly(RDS_RESOURCE_NAME); VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList(AUTHORITY), Collections.singletonList(route), ImmutableMap.of()); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + 
verify(mockListener).onResult2(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster1), (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); xdsClient.deliverRdsResourceNotFound(RDS_RESOURCE_NAME); // revoke RDS resource assertEmptyResolutionResult(RDS_RESOURCE_NAME); // Simulate management server adds back the previously used RDS resource. reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster1), (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); @@ -492,11 +573,15 @@ public void resolving_encounterErrorLdsWatcherOnly() { resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverError(Status.UNAVAILABLE.withDescription("server unreachable")); - verify(mockListener).onError(errorCaptor.capture()); - Status error = errorCaptor.getValue(); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = resolutionResultCaptor.getValue() + .getAttributes().get(InternalConfigSelector.KEY); + Result selectResult = configSelector.selectConfig( + newPickSubchannelArgs(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + Status error = selectResult.getStatus(); assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(error.getDescription()).isEqualTo("Unable to load LDS " + AUTHORITY - + ". 
xDS server returned: UNAVAILABLE: server unreachable"); + assertThat(error.getDescription()).contains(AUTHORITY); + assertThat(error.getDescription()).contains("server unreachable"); } @Test @@ -504,11 +589,15 @@ public void resolving_translateErrorLds() { resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverError(Status.NOT_FOUND.withDescription("server unreachable")); - verify(mockListener).onError(errorCaptor.capture()); - Status error = errorCaptor.getValue(); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = resolutionResultCaptor.getValue() + .getAttributes().get(InternalConfigSelector.KEY); + Result selectResult = configSelector.selectConfig( + newPickSubchannelArgs(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + Status error = selectResult.getStatus(); assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(error.getDescription()).isEqualTo("Unable to load LDS " + AUTHORITY - + ". xDS server returned: NOT_FOUND: server unreachable"); + assertThat(error.getDescription()).contains(AUTHORITY); + assertThat(error.getDescription()).contains("server unreachable"); assertThat(error.getCause()).isNull(); } @@ -518,15 +607,17 @@ public void resolving_encounterErrorLdsAndRdsWatchers() { FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); xdsClient.deliverError(Status.UNAVAILABLE.withDescription("server unreachable")); - verify(mockListener, times(2)).onError(errorCaptor.capture()); - Status error = errorCaptor.getAllValues().get(0); - assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(error.getDescription()).isEqualTo("Unable to load LDS " + AUTHORITY - + ". 
xDS server returned: UNAVAILABLE: server unreachable"); - error = errorCaptor.getAllValues().get(1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = resolutionResultCaptor.getValue() + .getAttributes().get(InternalConfigSelector.KEY); + Result selectResult = configSelector.selectConfig( + newPickSubchannelArgs(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + Status error = selectResult.getStatus(); assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(error.getDescription()).isEqualTo("Unable to load RDS " + RDS_RESOURCE_NAME - + ". xDS server returned: UNAVAILABLE: server unreachable"); + // XdsDepManager.buildUpdate doesn't allow this + // assertThat(error.getDescription()).contains(RDS_RESOURCE_NAME); + assertThat(error.getDescription()).contains(expectedLdsResourceName); + assertThat(error.getDescription()).contains("server unreachable"); } @SuppressWarnings("unchecked") @@ -534,7 +625,7 @@ public void resolving_encounterErrorLdsAndRdsWatchers() { public void resolving_matchingVirtualHostNotFound_matchingOverrideAuthority() { Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList("random"), @@ -543,11 +634,13 @@ public void resolving_matchingVirtualHostNotFound_matchingOverrideAuthority() { resolver = new XdsNameResolver(targetUri, null, AUTHORITY, "random", serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); resolver.start(mockListener); FakeXdsClient 
xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate(0L, Arrays.asList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster1), (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); @@ -557,7 +650,7 @@ public void resolving_matchingVirtualHostNotFound_matchingOverrideAuthority() { public void resolving_matchingVirtualHostNotFound_notMatchingOverrideAuthority() { Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList(AUTHORITY), @@ -566,10 +659,12 @@ public void resolving_matchingVirtualHostNotFound_notMatchingOverrideAuthority() resolver = new XdsNameResolver(targetUri, null, AUTHORITY, "random", serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); - xdsClient.deliverLdsUpdate(0L, Arrays.asList(virtualHost)); + xdsClient.deliverLdsUpdateOnly(0L, Arrays.asList(virtualHost)); + fakeClock.forwardTime(15, TimeUnit.SECONDS); assertEmptyResolutionResult("random"); } @@ -577,7 +672,8 @@ public void resolving_matchingVirtualHostNotFound_notMatchingOverrideAuthority() public void resolving_matchingVirtualHostNotFoundForOverrideAuthority() { resolver = new XdsNameResolver(targetUri, null, AUTHORITY, 
AUTHORITY, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate(0L, buildUnmatchedVirtualHosts()); @@ -604,11 +700,11 @@ public void resolving_matchingVirtualHostNotFoundInRdsResource() { private List buildUnmatchedVirtualHosts() { Route route1 = Route.forAction(RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); Route route2 = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); return Arrays.asList( VirtualHost.create("virtualhost-foo", Collections.singletonList("hello.googleapis.com"), @@ -625,13 +721,13 @@ public void resolved_noTimeout() { FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), null, null), // per-route timeout unset + cluster1, Collections.emptyList(), null, null, false), // per-route timeout unset ImmutableMap.of()); VirtualHost virtualHost = VirtualHost.create("does not matter", Collections.singletonList(AUTHORITY), Collections.singletonList(route), ImmutableMap.of()); xdsClient.deliverLdsUpdate(0L, Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + 
verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); assertCallSelectClusterResult(call1, configSelector, cluster1, null); @@ -643,14 +739,14 @@ public void resolved_fallbackToHttpMaxStreamDurationAsTimeout() { FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), null, null), // per-route timeout unset + cluster1, Collections.emptyList(), null, null, false), // per-route timeout unset ImmutableMap.of()); VirtualHost virtualHost = VirtualHost.create("does not matter", Collections.singletonList(AUTHORITY), Collections.singletonList(route), ImmutableMap.of()); xdsClient.deliverLdsUpdate(TimeUnit.SECONDS.toNanos(5L), Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); assertCallSelectClusterResult(call1, configSelector, cluster1, 5.0); @@ -661,7 +757,8 @@ public void retryPolicyInPerMethodConfigGeneratedByResolverIsValid() { ServiceConfigParser realParser = new ScParser( true, 5, 5, new AutoConfiguredLoadBalancerFactory("pick-first")); resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, realParser, syncContext, - scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), + rawBootstrap, metricRecorder, nameResolverArgs); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); RetryPolicy 
retryPolicy = RetryPolicy.create( @@ -675,9 +772,10 @@ public void retryPolicyInPerMethodConfigGeneratedByResolverIsValid() { cluster1, Collections.emptyList(), null, - retryPolicy), + retryPolicy, + false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); Result selectResult = configSelector.selectConfig( @@ -733,16 +831,16 @@ public void resolved_simpleCallFailedToRoute_routeWithNonForwardingAction() { Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster(cluster2, Collections.emptyList(), - TimeUnit.SECONDS.toNanos(15L), null), + TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster2), (Map) result.getServiceConfig().getConfig()); - assertThat(result.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL)).isNotNull(); - assertThat(result.getAttributes().get(InternalXdsAttributes.CALL_COUNTER_PROVIDER)).isNotNull(); + assertThat(result.getAttributes().get(XdsAttributes.XDS_CLIENT)).isNotNull(); + assertThat(result.getAttributes().get(XdsAttributes.CALL_COUNTER_PROVIDER)).isNotNull(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); // Simulates making a call1 RPC. 
Result selectResult = configSelector.selectConfig( @@ -769,9 +867,10 @@ public void resolved_rpcHashingByHeader_withoutSubstitution() { Collections.singletonList( HashPolicy.forHeader(false, "custom-key", null, null)), null, - null), + null, + false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); InternalConfigSelector configSelector = resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); @@ -801,11 +900,13 @@ public void resolved_rpcHashingByHeader_withSubstitution() { RouteAction.forCluster( cluster1, Collections.singletonList( - HashPolicy.forHeader(false, "custom-key", Pattern.compile("value"), "val")), + HashPolicy.forHeader(false, "custom-key", Pattern.compile("value"), + "val")), + null, null, - null), + false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); InternalConfigSelector configSelector = resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); @@ -842,9 +943,10 @@ public void resolved_rpcHashingByChannelId() { cluster1, Collections.singletonList(HashPolicy.forChannelId(false)), null, - null), + null, + false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); InternalConfigSelector configSelector = resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); @@ -863,10 +965,12 @@ public void resolved_rpcHashingByChannelId() { // A different resolver/Channel. 
resolver.shutdown(); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); when(mockRandom.nextLong()).thenReturn(123L); resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); resolver.start(mockListener); xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate( @@ -878,9 +982,10 @@ public void resolved_rpcHashingByChannelId() { cluster1, Collections.singletonList(HashPolicy.forChannelId(false)), null, - null), + null, + false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); configSelector = resolutionResultCaptor.getValue().getAttributes().get( InternalConfigSelector.KEY); @@ -894,6 +999,68 @@ public void resolved_rpcHashingByChannelId() { assertThat(hash3).isNotEqualTo(hash1); } + @Test + public void resolved_routeActionHasAutoHostRewrite_emitsCallOptionForTheSame() { + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, + syncContext, scheduler, xdsClientPoolFactory, mockRandom, + FilterRegistry.getDefaultRegistry(), rawBootstrap, metricRecorder, nameResolverArgs); + resolver.start(mockListener); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + xdsClient.deliverLdsUpdate( + Collections.singletonList( + Route.forAction( + RouteMatch.withPathExactOnly( + "/" + TestMethodDescriptors.voidMethod().getFullMethodName()), + RouteAction.forCluster( + cluster1, + Collections.singletonList( + HashPolicy.forHeader(false, "custom-key", null, null)), + null, + null, + true), + ImmutableMap.of()))); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = + 
resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); + + // First call, with header "custom-key": "custom-value". + startNewCall(TestMethodDescriptors.voidMethod(), configSelector, + ImmutableMap.of("custom-key", "custom-value"), CallOptions.DEFAULT); + + assertThat(testCall.callOptions.getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY)).isTrue(); + } + + @Test + public void resolved_routeActionNoAutoHostRewrite_doesntEmitCallOptionForTheSame() { + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, + syncContext, scheduler, xdsClientPoolFactory, mockRandom, + FilterRegistry.getDefaultRegistry(), rawBootstrap, metricRecorder, nameResolverArgs); + resolver.start(mockListener); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + xdsClient.deliverLdsUpdate( + Collections.singletonList( + Route.forAction( + RouteMatch.withPathExactOnly( + "/" + TestMethodDescriptors.voidMethod().getFullMethodName()), + RouteAction.forCluster( + cluster1, + Collections.singletonList( + HashPolicy.forHeader(false, "custom-key", null, null)), + null, + null, + false), + ImmutableMap.of()))); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = + resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); + + // First call, with header "custom-key": "custom-value". 
+ startNewCall(TestMethodDescriptors.voidMethod(), configSelector, + ImmutableMap.of("custom-key", "custom-value"), CallOptions.DEFAULT); + + assertThat(testCall.callOptions.getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY)).isNull(); + } + @SuppressWarnings("unchecked") @Test public void resolved_resourceUpdateAfterCallStarted() { @@ -902,6 +1069,7 @@ public void resolved_resourceUpdateAfterCallStarted() { TestCall firstCall = testCall; reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate( Arrays.asList( @@ -909,15 +1077,15 @@ public void resolved_resourceUpdateAfterCallStarted() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( "another-cluster", Collections.emptyList(), - TimeUnit.SECONDS.toNanos(20L), null), + TimeUnit.SECONDS.toNanos(20L), null, false), ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); // Updated service config still contains cluster1 while it is removed resource. New calls no // longer routed to cluster1. 
@@ -929,7 +1097,9 @@ public void resolved_resourceUpdateAfterCallStarted() { assertCallSelectClusterResult(call1, configSelector, "another-cluster", 20.0); firstCall.deliverErrorStatus(); // completes previous call - verify(mockListener, times(2)).onResult(resolutionResultCaptor.capture()); + // Two updates: one for XdsNameResolver releasing the cluster, and another for + // XdsDependencyManager updating the XdsConfig + verify(mockListener, times(3)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); assertServiceConfigForLoadBalancingConfig( Arrays.asList(cluster2, "another-cluster"), @@ -942,6 +1112,7 @@ public void resolved_resourceUpdateAfterCallStarted() { public void resolved_resourceUpdatedBeforeCallStarted() { InternalConfigSelector configSelector = resolveToClusters(); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate( Arrays.asList( @@ -949,17 +1120,17 @@ public void resolved_resourceUpdatedBeforeCallStarted() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( "another-cluster", Collections.emptyList(), - TimeUnit.SECONDS.toNanos(20L), null), + TimeUnit.SECONDS.toNanos(20L), null, false), ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()))); // Two consecutive service config updates: one for removing clcuster1, // one for adding "another=cluster". 
- verify(mockListener, times(2)).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(3)).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); assertServiceConfigForLoadBalancingConfig( Arrays.asList(cluster2, "another-cluster"), @@ -978,6 +1149,7 @@ public void resolved_raceBetweenCallAndRepeatedResourceUpdate() { assertCallSelectClusterResult(call1, configSelector, cluster1, 15.0); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate( Arrays.asList( @@ -985,16 +1157,16 @@ public void resolved_raceBetweenCallAndRepeatedResourceUpdate() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( "another-cluster", Collections.emptyList(), - TimeUnit.SECONDS.toNanos(20L), null), + TimeUnit.SECONDS.toNanos(20L), null, false), ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), - TimeUnit.SECONDS.toNanos(15L), null), + TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); assertServiceConfigForLoadBalancingConfig( Arrays.asList(cluster1, cluster2, "another-cluster"), @@ -1006,15 +1178,15 @@ public void resolved_raceBetweenCallAndRepeatedResourceUpdate() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( "another-cluster", Collections.emptyList(), - TimeUnit.SECONDS.toNanos(15L), null), + TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), - 
TimeUnit.SECONDS.toNanos(15L), null), + TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()))); - verifyNoMoreInteractions(mockListener); // no cluster added/deleted + verify(mockListener, times(2)).onResult2(resolutionResultCaptor.capture()); assertCallSelectClusterResult(call1, configSelector, "another-cluster", 15.0); } @@ -1029,7 +1201,7 @@ public void resolved_raceBetweenClusterReleasedAndResourceUpdateAddBackAgain() { RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()))); xdsClient.deliverLdsUpdate( Arrays.asList( @@ -1037,16 +1209,22 @@ public void resolved_raceBetweenClusterReleasedAndResourceUpdateAddBackAgain() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()))); testCall.deliverErrorStatus(); - verifyNoMoreInteractions(mockListener); + verify(mockListener, times(3)).onResult2(resolutionResultCaptor.capture()); + assertServiceConfigForLoadBalancingConfig( + Arrays.asList(cluster1, cluster2), resolutionResultCaptor.getAllValues().get(1)); + assertServiceConfigForLoadBalancingConfig( + Arrays.asList(cluster1, cluster2), resolutionResultCaptor.getAllValues().get(2)); + assertServiceConfigForLoadBalancingConfig( + Arrays.asList(cluster1, cluster2), resolutionResultCaptor.getAllValues().get(3)); } @SuppressWarnings("unchecked") @@ -1067,19 +1245,33 @@ public void resolved_simpleCallSucceeds_routeToWeightedCluster() { cluster2, 80, ImmutableMap.of())), Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), - null), + null, false), 
ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); assertServiceConfigForLoadBalancingConfig( Arrays.asList(cluster1, cluster2), (Map) result.getServiceConfig().getConfig()); - assertThat(result.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL)).isNotNull(); + assertThat(result.getAttributes().get(XdsAttributes.XDS_CLIENT)).isNotNull(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); assertCallSelectClusterResult(call1, configSelector, cluster2, 20.0); assertCallSelectClusterResult(call1, configSelector, cluster1, 20.0); } + /** Creates and delivers both CDS and EDS updates for the given clusters. */ + private static void createAndDeliverClusterUpdates( + FakeXdsClient xdsClient, String... 
clusterNames) { + for (String clusterName : clusterNames) { + CdsUpdate.Builder forEds = CdsUpdate + .forEds(clusterName, clusterName, null, null, null, null, false, null) + .roundRobinLbPolicy(); + xdsClient.deliverCdsUpdate(clusterName, forEds.build()); + EdsUpdate edsUpdate = new EdsUpdate(clusterName, + XdsTestUtils.createMinimalLbEndpointsMap("127.0.0.3"), Collections.emptyList()); + xdsClient.deliverEdsUpdate(clusterName, edsUpdate); + } + } + @Test public void resolved_simpleCallSucceeds_routeToRls() { when(mockRandom.nextInt(anyInt())).thenReturn(90, 10); @@ -1096,11 +1288,11 @@ public void resolved_simpleCallSucceeds_routeToRls() { ImmutableMap.of("lookupService", "rls-cbt.googleapis.com"))), Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), - null), + null, false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); @SuppressWarnings("unchecked") Map resultServiceConfig = (Map) result.getServiceConfig().getConfig(); List> rawLbConfigs = @@ -1113,7 +1305,7 @@ public void resolved_simpleCallSucceeds_routeToRls() { "routeLookupConfig", ImmutableMap.of("lookupService", "rls-cbt.googleapis.com"), "childPolicy", - ImmutableList.of(ImmutableMap.of("cds_experimental", ImmutableMap.of())), + ImmutableList.of(ImmutableMap.of("cds_experimental", ImmutableMap.of("is_dynamic", true))), "childPolicyConfigTargetFieldName", "cluster"); Map expectedClusterManagerLbConfig = ImmutableMap.of( @@ -1125,7 +1317,7 @@ public void resolved_simpleCallSucceeds_routeToRls() { ImmutableList.of(ImmutableMap.of("rls_experimental", expectedRlsLbConfig))))); assertThat(clusterManagerLbConfig).isEqualTo(expectedClusterManagerLbConfig); - 
assertThat(result.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL)).isNotNull(); + assertThat(result.getAttributes().get(XdsAttributes.XDS_CLIENT)).isNotNull(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); assertCallSelectRlsPluginResult( call1, configSelector, "rls-plugin-foo", 20.0); @@ -1144,9 +1336,9 @@ public void resolved_simpleCallSucceeds_routeToRls() { Collections.emptyList(), // changed TimeUnit.SECONDS.toNanos(30L), - null), + null, false), ImmutableMap.of()))); - verify(mockListener, times(2)).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(2)).onResult2(resolutionResultCaptor.capture()); ResolutionResult result2 = resolutionResultCaptor.getValue(); @SuppressWarnings("unchecked") Map resultServiceConfig2 = (Map) result2.getServiceConfig().getConfig(); @@ -1160,7 +1352,7 @@ public void resolved_simpleCallSucceeds_routeToRls() { "routeLookupConfig", ImmutableMap.of("lookupService", "rls-cbt-2.googleapis.com"), "childPolicy", - ImmutableList.of(ImmutableMap.of("cds_experimental", ImmutableMap.of())), + ImmutableList.of(ImmutableMap.of("cds_experimental", ImmutableMap.of("is_dynamic", true))), "childPolicyConfigTargetFieldName", "cluster"); Map expectedClusterManagerLbConfig2 = ImmutableMap.of( @@ -1177,11 +1369,379 @@ public void resolved_simpleCallSucceeds_routeToRls() { call1, configSelector2, "rls-plugin-foo", 30.0); } + // Begin filter state tests. + + /** + * Verifies the lifecycle of HCM filter instances across LDS updates. + * + *

Filter instances: + * 1. Must have one unique instance per HCM filter name. + * 2. Must be reused when an LDS update with HCM contains a filter with the same name. + * 3. Must be shutdown (closed) when an HCM in a LDS update doesn't a filter with the same name. + */ + @Test + public void filterState_survivesLds() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + VirtualHost vhost = filterStateTestVhost(); + + // LDS 1. + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + // Verify that StatefulFilter with different filter names result in different Filter instances. + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + assertThat(lds1Filter1).isNotSameInstanceAs(lds1Filter2); + // Redundant check just in case StatefulFilter synchronization is broken. + assertThat(lds1Filter1.idx).isEqualTo(0); + assertThat(lds1Filter2.idx).isEqualTo(1); + + // LDS 2: filter configs with the same names. + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds2Snapshot = statefulFilterProvider.getAllInstances(); + // Filter names hasn't changed, so expecting no new StatefulFilter instances. + assertWithMessage("LDS 2: Expected Filter instances to be reused across LDS updates") + .that(lds2Snapshot).isEqualTo(lds1Snapshot); + + // LDS 3: Filter "STATEFUL_2" removed. 
+ xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1)); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds3Snapshot = statefulFilterProvider.getAllInstances(); + // Again, no new StatefulFilter instances should be created. + assertWithMessage("LDS 3: Expected Filter instances to be reused across LDS updates") + .that(lds3Snapshot).isEqualTo(lds1Snapshot); + // Verify the shutdown state. + assertThat(lds1Filter1.isShutdown()).isFalse(); + assertWithMessage("LDS 3: Expected %s to be shut down", lds1Filter2) + .that(lds1Filter2.isShutdown()).isTrue(); + + // LDS 4: Filter "STATEFUL_2" added back. + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds4Snapshot = statefulFilterProvider.getAllInstances(); + // Filter "STATEFUL_2" should be treated as any other new filter name in an LDS update: + // a new instance should be created. + assertWithMessage("LDS 4: Expected a new filter instance for %s", STATEFUL_2) + .that(lds4Snapshot).hasSize(3); + StatefulFilter lds4Filter2 = lds4Snapshot.get(2); + assertThat(lds4Filter2.idx).isEqualTo(2); + assertThat(lds4Filter2).isNotSameInstanceAs(lds1Filter2); + assertThat(lds4Snapshot).containsAtLeastElementsIn(lds1Snapshot); + // Verify the shutdown state. + assertThat(lds1Filter1.isShutdown()).isFalse(); + assertThat(lds1Filter2.isShutdown()).isTrue(); + assertThat(lds4Filter2.isShutdown()).isFalse(); + } + + /** + * Verifies the lifecycle of HCM filter instances across RDS updates. + * + *

Filter instances: + * 1. Must have instantiated by the initial LDS/RDS. + * 2. Must be reused by all subsequent RDS updates. + * 3. Must be not shutdown (closed) by valid RDS updates. + */ + @Test + public void filterState_survivesRds() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + + // LDS 1. + xdsClient.deliverLdsUpdateForRdsNameWithFilters(RDS_RESOURCE_NAME, + filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + // RDS 1. + VirtualHost vhost1 = filterStateTestVhost(); + xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, vhost1); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + // Initial RDS update should not generate Filter instances. + ImmutableList rds1Snapshot = statefulFilterProvider.getAllInstances(); + // Verify that StatefulFilter with different filter names result in different Filter instances. + assertWithMessage("RDS 1: expected to create filter instances").that(rds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = rds1Snapshot.get(0); + StatefulFilter lds1Filter2 = rds1Snapshot.get(1); + assertThat(lds1Filter1).isNotSameInstanceAs(lds1Filter2); + + // RDS 2: exactly the same as RDS 1. + xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, vhost1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList rds2Snapshot = statefulFilterProvider.getAllInstances(); + // Neither should any subsequent RDS updates. + assertWithMessage("RDS 2: Expected Filter instances to be reused across RDS route updates") + .that(rds2Snapshot).isEqualTo(rds1Snapshot); + + // RDS 3: Contains a per-route override for STATEFUL_1. 
+ VirtualHost vhost3 = filterStateTestVhost(ImmutableMap.of( + STATEFUL_1, new StatefulFilter.Config("RDS3") + )); + xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, vhost3); + assertClusterResolutionResult(call1, cluster1); + ImmutableList rds3Snapshot = statefulFilterProvider.getAllInstances(); + // As with any other Route update, typed_per_filter_config overrides should not result in + // creating new filter instances. + assertWithMessage("RDS 3: Expected Filter instances to be reused on per-route filter overrides") + .that(rds3Snapshot).isEqualTo(rds1Snapshot); + } + + /** + * Verifies a special case where an existing filter is has a different typeUrl in a subsequent + * LDS update. + * + *

Expectations: + * 1. The old filter instance must be shutdown. + * 2. A new filter instance must be created for the new filter with different typeUrl. + */ + @Test + public void filterState_specialCase_sameNameDifferentTypeUrl() { + // Prepare filter registry with StatefulFilter of different typeUrl. + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + String altTypeUrl = "type.googleapis.com/grpc.test.AltStatefulFilter"; + StatefulFilter.Provider altStatefulFilterProvider = new StatefulFilter.Provider(altTypeUrl); + FilterRegistry filterRegistry = FilterRegistry.newRegistry() + .register(statefulFilterProvider, altStatefulFilterProvider, ROUTER_FILTER_PROVIDER); + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, + syncContext, scheduler, xdsClientPoolFactory, mockRandom, filterRegistry, rawBootstrap, + metricRecorder, nameResolverArgs); + resolver.start(mockListener); + + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + VirtualHost vhost = filterStateTestVhost(); + + // LDS 1. + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + ImmutableList lds1SnapshotAlt = altStatefulFilterProvider.getAllInstances(); + // Verify that StatefulFilter with different filter names result in different Filter instances. + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + assertThat(lds1Filter1).isNotSameInstanceAs(lds1Filter2); + // Nothing in the alternative provider. + assertThat(lds1SnapshotAlt).isEmpty(); + + // LDS 2: Filter STATEFUL_2 present, but with a different typeUrl: altTypeUrl. 
+ ImmutableList filterConfigs = ImmutableList.of( + new NamedFilterConfig(STATEFUL_1, new StatefulFilter.Config(STATEFUL_1)), + new NamedFilterConfig(STATEFUL_2, new StatefulFilter.Config(STATEFUL_2, altTypeUrl)), + new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG) + ); + xdsClient.deliverLdsUpdateWithFilters(vhost, filterConfigs); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds2Snapshot = statefulFilterProvider.getAllInstances(); + ImmutableList lds2SnapshotAlt = altStatefulFilterProvider.getAllInstances(); + // Filter "STATEFUL_2" has different typeUrl, and should be treated as a new filter. + // No changes in the snapshot of normal stateful filters. + assertWithMessage("LDS 2: expected a new filter instance of different type") + .that(lds2Snapshot).isEqualTo(lds1Snapshot); + // A new filter instance is created by altStatefulFilterProvider. + assertWithMessage("LDS 2: expected a new filter instance for type %s", altTypeUrl) + .that(lds2SnapshotAlt).hasSize(1); + StatefulFilter lds2Filter2Alt = lds2SnapshotAlt.get(0); + assertThat(lds2Filter2Alt).isNotSameInstanceAs(lds1Filter2); + // Verify the shutdown state. + assertThat(lds1Filter1.isShutdown()).isFalse(); + assertThat(lds1Filter2.isShutdown()).isTrue(); + assertThat(lds2Filter2Alt.isShutdown()).isFalse(); + } + + /** + * Verifies that all filter instances are shutdown (closed) on LDS resource not found. + */ + @Test + public void filterState_shutdown_onLdsNotFound() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + VirtualHost vhost = filterStateTestVhost(); + + // LDS 1. 
+ xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + + // LDS 2: resource not found. + reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); + xdsClient.deliverLdsResourceNotFound(); + assertEmptyResolutionResult(expectedLdsResourceName); + // Verify shutdown. + assertThat(lds1Filter1.isShutdown()).isTrue(); + assertThat(lds1Filter2.isShutdown()).isTrue(); + } + + @Test + public void filterState_noShutdown_onLdsDeletion() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + VirtualHost vhost = filterStateTestVhost(); + + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + + // LDS 2: Deliver a resource deletion, which is now an ambient error. + reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); + xdsClient.deliverLdsResourceDeletion(); + + // With an ambient error, no new resolution should happen. + verify(mockListener, never()).onResult2(any()); + + // Verify that the filters are NOT shut down. 
+ assertThat(lds1Filter1.isShutdown()).isFalse(); + assertThat(lds1Filter2.isShutdown()).isFalse(); + } + + /** + * Verifies that all filter instances are shutdown (closed) on LDS ResourceWatcher shutdown. + */ + @Test + public void filterState_shutdown_onResolverShutdown() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + VirtualHost vhost = filterStateTestVhost(); + + // LDS 1. + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + + // Shutdown. + resolver.shutdown(); + resolver = null; // no need to shutdown again in the teardown. + // Verify shutdown. + assertThat(lds1Filter1.isShutdown()).isTrue(); + assertThat(lds1Filter2.isShutdown()).isTrue(); + } + + /** + * Verifies that all filter instances are shutdown (closed) on RDS resource not found. 
+ */ + @Test + public void filterState_shutdown_onRdsNotFound() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + xdsClient.deliverLdsUpdateForRdsNameWithFilters( + RDS_RESOURCE_NAME, + filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + xdsClient.deliverRdsUpdate( + RDS_RESOURCE_NAME, + Collections.singletonList(filterStateTestVhost())); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + + ImmutableList rds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("RDS 1: Expected to create filter instances").that(rds1Snapshot).hasSize(2); + StatefulFilter rds1Filter1 = rds1Snapshot.get(0); + StatefulFilter rds1Filter2 = rds1Snapshot.get(1); + assertThat(rds1Filter1.isShutdown()).isFalse(); + assertThat(rds1Filter2.isShutdown()).isFalse(); + + reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); + xdsClient.deliverRdsResourceNotFound(RDS_RESOURCE_NAME); + + assertEmptyResolutionResult(RDS_RESOURCE_NAME); + assertThat(rds1Filter1.isShutdown()).isTrue(); + assertThat(rds1Filter2.isShutdown()).isTrue(); + } + + @Test + public void filterState_noShutdown_onRdsAmbientError() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + + // LDS 1. + xdsClient.deliverLdsUpdateForRdsNameWithFilters(RDS_RESOURCE_NAME, + filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + // RDS 1: Standard vhost with a route. 
+ xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, filterStateTestVhost()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList rds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("RDS 1: expected to create filter instances").that(rds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = rds1Snapshot.get(0); + StatefulFilter lds1Filter2 = rds1Snapshot.get(1); + + // RDS 2: RDS_RESOURCE_NAME not found. + reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); + xdsClient.deliverRdsAmbientError(RDS_RESOURCE_NAME, Status.NOT_FOUND); + verify(mockListener, never()).onResult2(any()); + assertThat(lds1Filter1.isShutdown()).isFalse(); + assertThat(lds1Filter2.isShutdown()).isFalse(); + } + + private StatefulFilter.Provider filterStateTestSetupResolver() { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = FilterRegistry.newRegistry() + .register(statefulFilterProvider, ROUTER_FILTER_PROVIDER); + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, + syncContext, scheduler, xdsClientPoolFactory, mockRandom, filterRegistry, rawBootstrap, + metricRecorder, nameResolverArgs); + resolver.start(mockListener); + return statefulFilterProvider; + } + + private ImmutableList filterStateTestConfigs(String... names) { + ImmutableList.Builder result = ImmutableList.builder(); + for (String name : names) { + result.add(new NamedFilterConfig(name, new StatefulFilter.Config(name))); + } + result.add(new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG)); + return result.build(); + } + + private Route filterStateTestRoute(ImmutableMap perRouteOverrides) { + // Standard basic route for filterState tests. 
+ return Route.forAction( + RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), + RouteAction.forCluster(cluster1, NO_HASH_POLICIES, null, null, true), + perRouteOverrides); + } + + private VirtualHost filterStateTestVhost() { + return filterStateTestVhost(NO_FILTER_OVERRIDES); + } + + private VirtualHost filterStateTestVhost(ImmutableMap perRouteOverrides) { + return VirtualHost.create( + "stateful-vhost", + ImmutableList.of(expectedLdsResourceName), + ImmutableList.of(filterStateTestRoute(perRouteOverrides)), + NO_FILTER_OVERRIDES); + } + + // End filter state tests. + @SuppressWarnings("unchecked") private void assertEmptyResolutionResult(String resource) { - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); assertThat((Map) result.getServiceConfig().getConfig()).isEmpty(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); Result configResult = configSelector.selectConfig( @@ -1190,6 +1750,13 @@ private void assertEmptyResolutionResult(String resource) { assertThat(configResult.getStatus().getDescription()).contains(resource); } + private void assertClusterResolutionResult(CallInfo call, String expectedCluster) { + verify(mockListener, atLeast(1)).onResult2(resolutionResultCaptor.capture()); + ResolutionResult result = resolutionResultCaptor.getValue(); + InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); + assertCallSelectClusterResult(call, configSelector, expectedCluster, null); + } + private void assertCallSelectClusterResult( CallInfo call, InternalConfigSelector configSelector, String expectedCluster, @Nullable Double expectedTimeoutSec) { @@ -1202,6 +1769,10 @@ private void 
assertCallSelectClusterResult( clientCall.start(new NoopClientCallListener<>(), new Metadata()); assertThat(testCall.callOptions.getOption(XdsNameResolver.CLUSTER_SELECTION_KEY)) .isEqualTo("cluster:" + expectedCluster); + XdsConfig xdsConfig = + testCall.callOptions.getOption(XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY); + assertThat(xdsConfig).isNotNull(); + assertThat(xdsConfig.getClusters()).containsKey(expectedCluster); // Without "cluster:" prefix @SuppressWarnings("unchecked") Map config = (Map) result.getConfig(); if (expectedTimeoutSec != null) { @@ -1250,24 +1821,31 @@ private InternalConfigSelector resolveToClusters() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); assertServiceConfigForLoadBalancingConfig( Arrays.asList(cluster1, cluster2), (Map) result.getServiceConfig().getConfig()); - assertThat(result.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL)).isNotNull(); - assertThat(result.getAttributes().get(InternalXdsAttributes.CALL_COUNTER_PROVIDER)).isNotNull(); + assertThat(result.getAttributes().get(XdsAttributes.XDS_CLIENT)).isNotNull(); + assertThat(result.getAttributes().get(XdsAttributes.CALL_COUNTER_PROVIDER)).isNotNull(); return result.getAttributes().get(InternalConfigSelector.KEY); } + private static void assertServiceConfigForLoadBalancingConfig( + List clusters, 
ResolutionResult result) { + @SuppressWarnings("unchecked") + Map config = (Map) result.getServiceConfig().getConfig(); + assertServiceConfigForLoadBalancingConfig(clusters, config); + } + /** * Verifies the raw service config contains an xDS load balancing config for the given clusters. */ @@ -1305,7 +1883,7 @@ public void generateServiceConfig_forClusterManagerLoadBalancingConfig() throws Route route1 = Route.forAction( RouteMatch.withPathExactOnly("HelloService/hi"), RouteAction.forCluster( - "cluster-foo", Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + "cluster-foo", Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); Route route2 = Route.forAction( RouteMatch.withPathExactOnly("HelloService/hello"), @@ -1315,7 +1893,7 @@ public void generateServiceConfig_forClusterManagerLoadBalancingConfig() throws ClusterWeight.create("cluster-baz", 50, ImmutableMap.of())), ImmutableList.of(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()); Map rlsConfig = ImmutableMap.of("lookupService", "rls.bigtable.google.com"); Route route3 = Route.forAction( @@ -1324,7 +1902,7 @@ public void generateServiceConfig_forClusterManagerLoadBalancingConfig() throws NamedPluginConfig.create("plugin-foo", RlsPluginConfig.create(rlsConfig)), Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), - null), + null, false), ImmutableMap.of()); resolver.start(mockListener); @@ -1335,8 +1913,9 @@ public void generateServiceConfig_forClusterManagerLoadBalancingConfig() throws ImmutableList.of(route1, route2, route3), ImmutableMap.of()); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); + createAndDeliverClusterUpdates(xdsClient, "cluster-foo", "cluster-bar", "cluster-baz"); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); String expectedServiceConfigJson = "{\n" + " \"loadBalancingConfig\": [{\n" 
@@ -1370,7 +1949,9 @@ public void generateServiceConfig_forClusterManagerLoadBalancingConfig() throws + " \"lookupService\": \"rls.bigtable.google.com\"\n" + " },\n" + " \"childPolicy\": [\n" - + " {\"cds_experimental\": {}}\n" + + " {\"cds_experimental\": {\n" + + " \"is_dynamic\": true\n" + + " }}\n" + " ],\n" + " \"childPolicyConfigTargetFieldName\": \"cluster\"\n" + " }\n" @@ -1431,7 +2012,6 @@ public void generateServiceConfig_forPerMethodConfig() throws IOException { assertThat(XdsNameResolver.generateServiceConfigWithMethodConfig(null, retryPolicy)) .isEqualTo(expectedServiceConfig); - // timeout and retry expectedServiceConfigJson = "{\n" + " \"methodConfig\": [{\n" @@ -1520,7 +2100,7 @@ public void resolved_faultAbortInLdsUpdate() { FaultAbort.forHeader(FaultConfig.FractionalPercent.perHundred(70)), null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); // no header abort key provided in metadata, rpc should succeed @@ -1559,7 +2139,7 @@ public void resolved_faultAbortInLdsUpdate() { FaultAbort.forHeader(FaultConfig.FractionalPercent.perMillion(600_000)), null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(2)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1575,7 +2155,7 @@ public void resolved_faultAbortInLdsUpdate() { FaultAbort.forHeader(FaultConfig.FractionalPercent.perMillion(0)), 
null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(3)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1590,7 +2170,7 @@ public void resolved_faultAbortInLdsUpdate() { FaultConfig.FractionalPercent.perMillion(600_000)), null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(4)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1608,7 +2188,7 @@ public void resolved_faultAbortInLdsUpdate() { FaultConfig.FractionalPercent.perMillion(400_000)), null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(5)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1626,7 +2206,7 @@ public void resolved_faultDelayInLdsUpdate() { FaultConfig httpFilterFaultConfig = FaultConfig.create( FaultDelay.forHeader(FaultConfig.FractionalPercent.perHundred(70)), null, null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + 
verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); // no header delay key provided in metadata, rpc should succeed immediately @@ -1643,7 +2223,7 @@ public void resolved_faultDelayInLdsUpdate() { httpFilterFaultConfig = FaultConfig.create( FaultDelay.forHeader(FaultConfig.FractionalPercent.perMillion(600_000)), null, null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(2)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1654,7 +2234,7 @@ public void resolved_faultDelayInLdsUpdate() { httpFilterFaultConfig = FaultConfig.create( FaultDelay.forHeader(FaultConfig.FractionalPercent.perMillion(0)), null, null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(3)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1667,7 +2247,7 @@ public void resolved_faultDelayInLdsUpdate() { null, null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(4)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = 
result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1680,7 +2260,7 @@ public void resolved_faultDelayInLdsUpdate() { null, null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(5)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1699,7 +2279,7 @@ public void resolved_faultDelayWithMaxActiveStreamsInLdsUpdate() { null, /* maxActiveFaults= */ 1); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); @@ -1729,7 +2309,7 @@ public void resolved_faultDelayInLdsUpdate_callWithEarlyDeadline() { null, null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); @@ -1745,7 +2325,7 @@ public long nanoTime() { assertThat(testCall).isNull(); verifyRpcDelayedThenAborted(observer, 4000L, Status.DEADLINE_EXCEEDED.withDescription( "Deadline exceeded after up to 5000 ns of fault-injected delay:" - + " Deadline CallOptions will be exceeded in 0.000004000s. 
")); + + " Deadline CallOptions was exceeded after 0.000004000s")); } @Test @@ -1761,7 +2341,7 @@ public void resolved_faultAbortAndDelayInLdsUpdateInLdsUpdate() { FaultConfig.FractionalPercent.perMillion(1000_000)), null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), @@ -1790,7 +2370,7 @@ public void resolved_faultConfigOverrideInLdsUpdate() { null); xdsClient.deliverLdsUpdateWithFaultInjection( cluster1, httpFilterFaultConfig, virtualHostFaultConfig, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), @@ -1805,7 +2385,7 @@ public void resolved_faultConfigOverrideInLdsUpdate() { null); xdsClient.deliverLdsUpdateWithFaultInjection( cluster1, httpFilterFaultConfig, virtualHostFaultConfig, routeFaultConfig, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(2)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1822,7 +2402,7 @@ public void resolved_faultConfigOverrideInLdsUpdate() { xdsClient.deliverLdsUpdateWithFaultInjection( cluster1, httpFilterFaultConfig, virtualHostFaultConfig, 
routeFaultConfig, weightedClusterFaultConfig); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(3)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1851,7 +2431,7 @@ public void resolved_faultConfigOverrideInLdsAndInRdsUpdate() { FaultAbort.forStatus(Status.UNKNOWN, FaultConfig.FractionalPercent.perMillion(1000_000)), null); xdsClient.deliverRdsUpdateWithFaultInjection(RDS_RESOURCE_NAME, null, routeFaultConfig, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), @@ -1914,9 +2494,7 @@ private PickSubchannelArgs newPickSubchannelArgs( private final class FakeXdsClientPoolFactory implements XdsClientPoolFactory { Set targets = new HashSet<>(); - - @Override - public void setBootstrapOverride(Map bootstrap) {} + XdsClient xdsClient = new FakeXdsClient(); @Override @Nullable @@ -1925,12 +2503,13 @@ public ObjectPool get(String target) { } @Override - public ObjectPool getOrCreate(String target) throws XdsInitializationException { + public ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder) { targets.add(target); return new ObjectPool() { @Override public XdsClient getObject() { - return new FakeXdsClient(); + return xdsClient; } @Override @@ -1954,9 +2533,10 @@ public List getTargets() { private class FakeXdsClient extends XdsClient { // Should never be subscribing to more than one LDS and RDS resource at any point of time. 
private String ldsResource; // should always be AUTHORITY - private String rdsResource; private ResourceWatcher ldsWatcher; - private ResourceWatcher rdsWatcher; + private final Map>> rdsWatchers = new HashMap<>(); + private final Map>> cdsWatchers = new HashMap<>(); + private final Map>> edsWatchers = new HashMap<>(); @Override public BootstrapInfo getBootstrapInfo() { @@ -1979,15 +2559,22 @@ public void watchXdsResource(XdsResourceType resou ldsWatcher = (ResourceWatcher) watcher; break; case "RDS": - assertThat(rdsResource).isNull(); - assertThat(rdsWatcher).isNull(); - rdsResource = resourceName; - rdsWatcher = (ResourceWatcher) watcher; + rdsWatchers.computeIfAbsent(resourceName, k -> new ArrayList<>()) + .add((ResourceWatcher) watcher); + break; + case "CDS": + cdsWatchers.computeIfAbsent(resourceName, k -> new ArrayList<>()) + .add((ResourceWatcher) watcher); + break; + case "EDS": + edsWatchers.computeIfAbsent(resourceName, k -> new ArrayList<>()) + .add((ResourceWatcher) watcher); break; default: } } + @SuppressWarnings("unchecked") @Override public void cancelXdsResourceWatch(XdsResourceType type, String resourceName, @@ -2001,19 +2588,57 @@ public void cancelXdsResourceWatch(XdsResourceType ldsWatcher = null; break; case "RDS": - assertThat(rdsResource).isNotNull(); - assertThat(rdsWatcher).isNotNull(); - rdsResource = null; - rdsWatcher = null; + assertThat(rdsWatchers).containsKey(resourceName); + assertThat(rdsWatchers.get(resourceName)).contains(watcher); + rdsWatchers.get(resourceName).remove((ResourceWatcher) watcher); + if (rdsWatchers.get(resourceName).isEmpty()) { + rdsWatchers.remove(resourceName); + } + break; + case "CDS": + assertThat(cdsWatchers).containsKey(resourceName); + assertThat(cdsWatchers.get(resourceName)).contains(watcher); + cdsWatchers.get(resourceName).remove((ResourceWatcher) watcher); + break; + case "EDS": + assertThat(edsWatchers).containsKey(resourceName); + assertThat(edsWatchers.get(resourceName)).contains(watcher); + 
edsWatchers.get(resourceName).remove((ResourceWatcher) watcher); break; default: } } + void deliverRdsAmbientError(String resourceName, Status status) { + if (!rdsWatchers.containsKey(resourceName)) { + return; + } + syncContext.execute(() -> { + List> resourceWatchers = + ImmutableList.copyOf(rdsWatchers.get(resourceName)); + resourceWatchers.forEach(w -> w.onAmbientError(status)); + }); + } + + void deliverLdsUpdateOnly(long httpMaxStreamDurationNano, List virtualHosts) { + syncContext.execute(() -> { + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( + httpMaxStreamDurationNano, virtualHosts, null)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); + }); + } + void deliverLdsUpdate(long httpMaxStreamDurationNano, List virtualHosts) { + List clusterNames = new ArrayList<>(); + for (VirtualHost vh : virtualHosts) { + clusterNames.addAll(getClusterNames(vh.routes())); + } + syncContext.execute(() -> { - ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( - httpMaxStreamDurationNano, virtualHosts, null))); + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( + httpMaxStreamDurationNano, virtualHosts, null)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); + createAndDeliverClusterUpdates(this, clusterNames.toArray(new String[0])); }); } @@ -2022,9 +2647,23 @@ void deliverLdsUpdate(final List routes) { VirtualHost.create( "virtual-host", Collections.singletonList(expectedLdsResourceName), routes, ImmutableMap.of()); + List clusterNames = getClusterNames(routes); + + syncContext.execute(() -> { + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), null)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); + if (!clusterNames.isEmpty()) { + createAndDeliverClusterUpdates(this, clusterNames.toArray(new String[0])); + } + }); + } + + void 
deliverLdsUpdateWithFilters(VirtualHost vhost, List filterConfigs) { syncContext.execute(() -> { - ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( - 0L, Collections.singletonList(virtualHost), null))); + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(vhost), filterConfigs)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); }); } @@ -2058,7 +2697,8 @@ void deliverLdsUpdateWithFaultInjection( Collections.singletonList(clusterWeight), Collections.emptyList(), null, - null), + null, + false), overrideConfig); overrideConfig = virtualHostFaultConfig == null ? ImmutableMap.of() @@ -2070,8 +2710,10 @@ void deliverLdsUpdateWithFaultInjection( Collections.singletonList(route), overrideConfig); syncContext.execute(() -> { - ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( - 0L, Collections.singletonList(virtualHost), filterChain))); + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), filterChain)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); + createAndDeliverClusterUpdates(this, cluster); }); } @@ -2085,30 +2727,70 @@ void deliverLdsUpdateForRdsNameWithFaultInjection( new NamedFilterConfig(FAULT_FILTER_INSTANCE_NAME, httpFilterFaultConfig), new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG)); syncContext.execute(() -> { - ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forRdsName( - 0L, rdsName, filterChain))); + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forRdsName( + 0L, rdsName, filterChain)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); }); } void deliverLdsUpdateForRdsName(String rdsName) { + deliverLdsUpdateForRdsNameWithFilters(rdsName, null); + } + + void deliverLdsUpdateForRdsNameWithFilters( + String rdsName, + 
@Nullable List filterConfigs) { + syncContext.execute(() -> { + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forRdsName( + 0, rdsName, filterConfigs)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); + }); + } + + void deliverLdsResourceDeletion() { + Status status = Status.NOT_FOUND.withDescription( + "Resource not found: " + expectedLdsResourceName); syncContext.execute(() -> { - ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forRdsName( - 0, rdsName, null))); + ldsWatcher.onAmbientError(status); }); } void deliverLdsResourceNotFound() { + Status notFoundStatus = Status.UNAVAILABLE.withDescription( + "Resource not found: " + expectedLdsResourceName); syncContext.execute(() -> { - ldsWatcher.onResourceDoesNotExist(expectedLdsResourceName); + if (ldsWatcher != null) { + ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); + } }); } + private List getClusterNames(List routes) { + List clusterNames = new ArrayList<>(); + for (Route r : routes) { + if (r.routeAction() == null) { + continue; + } + String cluster = r.routeAction().cluster(); + if (cluster != null) { + clusterNames.add(cluster); + } else { + List weightedClusters = r.routeAction().weightedClusters(); + if (weightedClusters == null) { + continue; + } + for (ClusterWeight wc : weightedClusters) { + clusterNames.add(wc.name()); + } + } + } + + return clusterNames; + } + void deliverRdsUpdateWithFaultInjection( String resourceName, @Nullable FaultConfig virtualHostFaultConfig, @Nullable FaultConfig routFaultConfig, @Nullable FaultConfig weightedClusterFaultConfig) { - if (!resourceName.equals(rdsResource)) { - return; - } ImmutableMap overrideConfig = weightedClusterFaultConfig == null ? 
ImmutableMap.of() : ImmutableMap.of( @@ -2125,7 +2807,8 @@ void deliverRdsUpdateWithFaultInjection( Collections.singletonList(clusterWeight), Collections.emptyList(), null, - null), + null, + false), overrideConfig); overrideConfig = virtualHostFaultConfig == null ? ImmutableMap.of() @@ -2136,40 +2819,78 @@ void deliverRdsUpdateWithFaultInjection( Collections.singletonList(expectedLdsResourceName), Collections.singletonList(route), overrideConfig); - syncContext.execute(() -> { - rdsWatcher.onChanged(new RdsUpdate(Collections.singletonList(virtualHost))); - }); + deliverRdsUpdate(resourceName, virtualHost); + createAndDeliverClusterUpdates(this, cluster1); } void deliverRdsUpdate(String resourceName, List virtualHosts) { - if (!resourceName.equals(rdsResource)) { + if (!rdsWatchers.containsKey(resourceName)) { return; } syncContext.execute(() -> { - rdsWatcher.onChanged(new RdsUpdate(virtualHosts)); + RdsUpdate update = new RdsUpdate(virtualHosts); + List> resourceWatchers = + ImmutableList.copyOf(rdsWatchers.get(resourceName)); + resourceWatchers.forEach(w -> w.onResourceChanged(StatusOr.fromValue(update))); }); } + void deliverRdsUpdate(String resourceName, VirtualHost virtualHost) { + deliverRdsUpdate(resourceName, ImmutableList.of(virtualHost)); + } + void deliverRdsResourceNotFound(String resourceName) { - if (!resourceName.equals(rdsResource)) { + if (!rdsWatchers.containsKey(resourceName)) { + return; + } + syncContext.execute(() -> { + List> resourceWatchers = + ImmutableList.copyOf(rdsWatchers.get(resourceName)); + Status status = Status.UNAVAILABLE.withDescription("Resource not found: " + resourceName); + resourceWatchers.forEach(w -> w.onResourceChanged(StatusOr.fromStatus(status))); + }); + } + + private void deliverCdsUpdate(String clusterName, CdsUpdate update) { + if (!cdsWatchers.containsKey(clusterName)) { return; } syncContext.execute(() -> { - rdsWatcher.onResourceDoesNotExist(rdsResource); + List> resourceWatchers = + 
ImmutableList.copyOf(cdsWatchers.get(clusterName)); + resourceWatchers.forEach(w -> w.onResourceChanged(StatusOr.fromValue(update))); + }); + } + + private void deliverEdsUpdate(String name, EdsUpdate update) { + syncContext.execute(() -> { + if (!edsWatchers.containsKey(name)) { + return; + } + List> resourceWatchers = + ImmutableList.copyOf(edsWatchers.get(name)); + resourceWatchers.forEach(w -> w.onResourceChanged(StatusOr.fromValue(update))); }); } + void deliverError(final Status error) { if (ldsWatcher != null) { syncContext.execute(() -> { - ldsWatcher.onError(error); - }); - } - if (rdsWatcher != null) { - syncContext.execute(() -> { - rdsWatcher.onError(error); + ldsWatcher.onResourceChanged(StatusOr.fromStatus(error)); }); } + syncContext.execute(() -> { + List> rdsCopy = rdsWatchers.values().stream() + .flatMap(List::stream).collect(java.util.stream.Collectors.toList()); + List> cdsCopy = cdsWatchers.values().stream() + .flatMap(List::stream).collect(java.util.stream.Collectors.toList()); + List> edsCopy = edsWatchers.values().stream() + .flatMap(List::stream).collect(java.util.stream.Collectors.toList()); + rdsCopy.forEach(w -> w.onResourceChanged(StatusOr.fromStatus(error))); + cdsCopy.forEach(w -> w.onResourceChanged(StatusOr.fromStatus(error))); + edsCopy.forEach(w -> w.onResourceChanged(StatusOr.fromStatus(error))); + }); } } diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 2c349eec4af..c8ad9f1c670 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -24,13 +24,20 @@ import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static 
io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_SPIFFE_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_SPIFFE_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SPIFFE_TRUST_MAP_1_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SPIFFE_TRUST_MAP_FILE; import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.SettableFuture; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.Grpc; @@ -43,6 +50,7 @@ import io.grpc.Server; import io.grpc.ServerCredentials; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.StatusRuntimeException; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; @@ -61,35 +69,60 @@ import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.internal.Matchers.HeaderMatcher; +import io.grpc.xds.internal.XdsInternalAttributes; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.SecurityProtocolNegotiators; import io.grpc.xds.internal.security.SslContextProviderSupplier; import io.grpc.xds.internal.security.TlsContextManagerImpl; +import io.grpc.xds.internal.security.certprovider.FileWatcherCertificateProviderProvider; import 
io.netty.handler.ssl.NotSslRecordException; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.net.Inet4Address; import java.net.InetSocketAddress; import java.net.URI; -import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.TrustManagerFactory; import org.junit.After; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; /** * Unit tests for {@link XdsChannelCredentials} and {@link XdsServerBuilder} for plaintext/TLS/mTLS * modes. 
*/ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class XdsSecurityClientServerTest { + + private static final String SNI_IN_UTC = "waterzooi.test.google.be"; + + @Parameter + public Boolean enableSpiffe; + private Boolean originalEnableSpiffe; @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); private int port; @@ -101,12 +134,37 @@ public class XdsSecurityClientServerTest { private FakeXdsClient xdsClient = new FakeXdsClient(); private FakeXdsClientPoolFactory fakePoolFactory = new FakeXdsClientPoolFactory(xdsClient); private static final String OVERRIDE_AUTHORITY = "foo.test.google.fr"; + private Attributes sslContextAttributes; + + @Parameters(name = "enableSpiffe={0}") + public static Collection data() { + return ImmutableList.of(true, false); + } + + @Before + public void setUp() throws IOException { + saveEnvironment(); + FileWatcherCertificateProviderProvider.enableSpiffe = enableSpiffe; + } + + private void saveEnvironment() { + originalEnableSpiffe = FileWatcherCertificateProviderProvider.enableSpiffe; + } @After - public void tearDown() { + public void tearDown() throws IOException { if (fakeNameResolverFactory != null) { NameResolverRegistry.getDefaultRegistry().deregister(fakeNameResolverFactory); } + FileWatcherCertificateProviderProvider.enableSpiffe = originalEnableSpiffe; + if (sslContextAttributes != null) { + SslContextProviderSupplier sslContextProviderSupplier = sslContextAttributes.get( + SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); + if (sslContextProviderSupplier != null) { + sslContextProviderSupplier.close(); + } + sslContextAttributes = null; + } } @Test @@ -133,30 +191,311 @@ public void nullFallbackCredentials_expectException() throws Exception { @Test public void tlsClientServer_noClientAuthentication() throws Exception { DownstreamTlsContext downstreamTlsContext = - setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false); + 
setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); buildServerWithTlsContext(downstreamTlsContext); // for TLS, client only needs trustCa UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false); + CLIENT_KEY_FILE, CLIENT_PEM_FILE, null, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); } + /** + * Use system root ca cert for TLS channel - no mTLS. + * Uses common_tls_context.combined_validation_context in upstream_tls_context. + */ + @Test + public void tlsClientServer_useSystemRootCerts_noMtls_useCombinedValidationContext() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, SNI_IN_UTC, false, "", false, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + /** + * Use system root ca cert for TLS channel - no mTLS. + * Uses common_tls_context.validation_context in upstream_tls_context. 
+ */ + @Test + public void tlsClientServer_useSystemRootCerts_noMtls_validationContext() throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa().toAbsolutePath(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, false, SNI_IN_UTC, false, null, false, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } finally { + Files.deleteIfExists(trustStoreFilePath.toAbsolutePath()); + clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_useSystemRootCerts_mtls() throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, true); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, SNI_IN_UTC, true, "", false, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } finally { + Files.deleteIfExists(trustStoreFilePath); + 
clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_noAutoSniValidation_failureToMatchSubjAltNames() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, "server1.test.google.in", false, "", false, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + fail("Expected handshake failure exception"); + } catch (StatusRuntimeException e) { + assertThat(e.getCause()).isInstanceOf(SSLHandshakeException.class); + assertThat(e.getCause().getCause()).isInstanceOf(CertificateException.class); + assertThat(e.getCause().getCause().getMessage()).isEqualTo( + "Peer certificate SAN check failed"); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + + @Test + public void tlsClientServer_autoSniValidation_sniInUtc() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, + // SAN matcher in 
CommonValidationContext. Will be overridden by autoSniSanValidation + "server1.test.google.in", + false, + SNI_IN_UTC, + false, true); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_autoSniValidation_sniFromHostname() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, + // SAN matcher in CommonValidationContext. 
Will be overridden by autoSniSanValidation + "server1.test.google.in", + false, + "", + true, true); + + // TODO: Change this to foo.test.gooogle.fr that needs wildcard matching after + // https://github.com/grpc/grpc-java/pull/12345 is done + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY, + "waterzooi.test.google.be"); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_autoSniValidation_noSniApplicable_usesMatcherFromCmnVdnCtx() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, + // This is what will get used for the SAN validation since no SNI was used + "waterzooi.test.google.be", + false, + "", + false, true); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + /** + * Use system root ca cert for TLS channel - mTLS. 
+ */ + @Test + public void tlsClientServer_useSystemRootCerts_requireClientAuth() throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa().toAbsolutePath(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, true); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, SNI_IN_UTC, false, "", false, false); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } finally { + Files.deleteIfExists(trustStoreFilePath.toAbsolutePath()); + clearTrustStoreSystemProperties(); + } + } + + private Path getCacertFilePathForTestCa() + throws IOException, KeyStoreException, CertificateException, NoSuchAlgorithmException { + KeyStore keystore = KeyStore.getInstance(KeyStore.getDefaultType()); + keystore.load(null, null); + InputStream caCertStream = getClass().getResource("/certs/ca.pem").openStream(); + keystore.setCertificateEntry("testca", CertificateFactory.getInstance("X.509") + .generateCertificate(caCertStream)); + caCertStream.close(); + File trustStoreFile = File.createTempFile("testca-truststore", "jks"); + FileOutputStream out = new FileOutputStream(trustStoreFile); + keystore.store(out, "changeit".toCharArray()); + out.close(); + return trustStoreFile.toPath(); + } + + @Test + public void tlsClientServer_Spiffe_noClientAuthentication() throws Exception { + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_SPIFFE_PEM_FILE, null, null, null, + null, null, false, false); + 
buildServerWithTlsContext(downstreamTlsContext); + + // for TLS, client only needs trustCa, so BAD certs don't matter + UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( + BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, SPIFFE_TRUST_MAP_FILE, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } + + @Test + public void tlsClientServer_Spiffe_noClientAuthentication_wrongServerCert() throws Exception { + if (!enableSpiffe) { + return; + } + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + // for TLS, client only needs trustCa, so BAD certs don't matter + UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( + BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, SPIFFE_TRUST_MAP_FILE, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + try { + unaryRpc("buddy", blockingStub); + fail("exception expected"); + } catch (StatusRuntimeException sre) { + assertThat(sre.getStatus().getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(sre.getCause().getCause().getMessage()) + .contains("Failed to extract SPIFFE ID from peer leaf certificate"); + } + } + @Test public void requireClientAuth_noClientCert_expectException() throws Exception { DownstreamTlsContext downstreamTlsContext = - setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, true, true); + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, true, true); buildServerWithTlsContext(downstreamTlsContext); // for TLS, client only uses trustCa 
UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false); + CLIENT_KEY_FILE, CLIENT_PEM_FILE, null, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -178,12 +517,12 @@ public void requireClientAuth_noClientCert_expectException() @Test public void noClientAuth_sendBadClientCert_passes() throws Exception { DownstreamTlsContext downstreamTlsContext = - setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false); + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); buildServerWithTlsContext(downstreamTlsContext); UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - BAD_CLIENT_KEY_FILE, - BAD_CLIENT_PEM_FILE, true); + BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, null, true); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -193,8 +532,7 @@ public void noClientAuth_sendBadClientCert_passes() throws Exception { @Test public void mtls_badClientCert_expectException() throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - BAD_CLIENT_KEY_FILE, - BAD_CLIENT_PEM_FILE, true); + BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, null, true); try { performMtlsTestAndGetListenerWatcher(upstreamTlsContext, null, null, null, null); fail("exception expected"); @@ -210,20 +548,58 @@ public void mtls_badClientCert_expectException() throws Exception { } } - /** mTLS - client auth enabled - using {@link XdsChannelCredentials} API. */ + /** mTLS with Spiffe Trust Bundle - client auth enabled - using {@link XdsChannelCredentials} + * API. 
*/ + @Test + public void mtlsClientServer_Spiffe_withClientAuthentication_withXdsChannelCreds() + throws Exception { + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_SPIFFE_PEM_FILE, null, null, null, + null, SPIFFE_TRUST_MAP_1_FILE, true, true); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( + CLIENT_KEY_FILE, CLIENT_SPIFFE_PEM_FILE, SPIFFE_TRUST_MAP_1_FILE, true); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } + + @Test + public void mtlsClientServer_Spiffe_badClientCert_expectException() + throws Exception { + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_SPIFFE_PEM_FILE, null, null, null, + null, SPIFFE_TRUST_MAP_1_FILE, true, true); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( + CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, SPIFFE_TRUST_MAP_1_FILE, true); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + try { + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + fail("exception expected"); + } catch (StatusRuntimeException sre) { + assertThat(sre.getStatus().getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(sre.getMessage()).contains("ssl exception"); + } + } + @Test public void mtlsClientServer_withClientAuthentication_withXdsChannelCreds() throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true); + CLIENT_KEY_FILE, CLIENT_PEM_FILE, null, true); 
performMtlsTestAndGetListenerWatcher(upstreamTlsContext, null, null, null, null); } @Test public void tlsServer_plaintextClient_expectException() throws Exception { DownstreamTlsContext downstreamTlsContext = - setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false); + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); buildServerWithTlsContext(downstreamTlsContext); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = @@ -243,8 +619,7 @@ public void plaintextServer_tlsClient_expectException() throws Exception { // for TLS, client only needs trustCa UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false); + CLIENT_KEY_FILE, CLIENT_PEM_FILE, null, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -262,15 +637,14 @@ public void plaintextServer_tlsClient_expectException() throws Exception { public void mtlsClientServer_changeServerContext_expectException() throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true); + CLIENT_KEY_FILE, CLIENT_PEM_FILE, null, true); performMtlsTestAndGetListenerWatcher(upstreamTlsContext, "cert-instance-name2", BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContext( "cert-instance-name2", true, true); - EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0", + EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0:0", downstreamTlsContext, tlsContextManagerForServer); xdsClient.deliverLdsUpdate(LdsUpdate.forTcpListener(listener)); @@ -290,8 +664,8 @@ private void performMtlsTestAndGetListenerWatcher( String privateKey2, String cert2, String 
trustCa2) throws Exception { DownstreamTlsContext downstreamTlsContext = - setBootstrapInfoAndBuildDownstreamTlsContext(certInstanceName2, privateKey2, cert2, - trustCa2, true, true); + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, certInstanceName2, + privateKey2, cert2, trustCa2, null, true, false); buildServerWithFallbackServerCredentials( InsecureServerCredentials.create(), downstreamTlsContext); @@ -302,26 +676,58 @@ private void performMtlsTestAndGetListenerWatcher( } private DownstreamTlsContext setBootstrapInfoAndBuildDownstreamTlsContext( - String certInstanceName2, - String privateKey2, - String cert2, String trustCa2, boolean hasRootCert, boolean requireClientCertificate) { + String cert1, String certInstanceName2, String privateKey2, + String cert2, String trustCa2, String spiffeFile, + boolean hasRootCert, boolean requireClientCertificate) { bootstrapInfoForServer = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE, - SERVER_1_PEM_FILE, CA_PEM_FILE, certInstanceName2, privateKey2, cert2, trustCa2); + cert1, CA_PEM_FILE, certInstanceName2, privateKey2, cert2, trustCa2, spiffeFile); return CommonTlsContextTestsUtil.buildDownstreamTlsContext( "google_cloud_private_spiffe-server", hasRootCert, requireClientCertificate); } private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String clientKeyFile, - String clientPemFile, - boolean hasIdentityCert) { + String clientPemFile, String spiffeFile, boolean hasIdentityCert) { bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, - CA_PEM_FILE, null, null, null, null); + CA_PEM_FILE, null, null, null, null, spiffeFile); return CommonTlsContextTestsUtil .buildUpstreamTlsContext("google_cloud_private_spiffe-client", hasIdentityCert); } + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names + private 
UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts( + String clientKeyFile, + String clientPemFile, + boolean useCombinedValidationContext, + String sanToMatch, + boolean isMtls, + String sniInUpstreamTlsContext, + boolean autoHostSni, boolean autoSniSanValidation) { + bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, + CA_PEM_FILE, null, null, null, null, null); + if (useCombinedValidationContext) { + return CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + isMtls ? "google_cloud_private_spiffe-client" : null, + isMtls ? "ROOT" : null, null, + null, null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .addMatchSubjectAltNames( + StringMatcher.newBuilder() + .setExact(sanToMatch)) + .build(), sniInUpstreamTlsContext, autoHostSni, autoSniSanValidation); + } + return CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + "google_cloud_private_spiffe-client", "ROOT", null, + null, null, CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .build()); + } + private void buildServerWithTlsContext(DownstreamTlsContext downstreamTlsContext) throws Exception { buildServerWithTlsContext(downstreamTlsContext, InsecureServerCredentials.create()); @@ -340,6 +746,7 @@ private void buildServerWithFallbackServerCredentials( ServerCredentials xdsCredentials = XdsServerCredentials.create(fallbackCredentials); XdsServerBuilder builder = XdsServerBuilder.forPort(0, xdsCredentials) .xdsClientPoolFactory(fakePoolFactory) + .overrideBootstrapForTest(XdsServerTestHelper.RAW_BOOTSTRAP) .addService(new SimpleServiceImpl()); buildServer(builder, downstreamTlsContext); } @@ -351,7 +758,7 @@ private void buildServer( tlsContextManagerForServer = new 
TlsContextManagerImpl(bootstrapInfoForServer); XdsServerWrapper xdsServer = (XdsServerWrapper) builder.build(); SettableFuture startFuture = startServerAsync(xdsServer); - EnvoyServerProtoData.Listener listener = buildListener("listener1", "10.1.2.3", + EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0:0", downstreamTlsContext, tlsContextManagerForServer); LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); xdsClient.deliverLdsUpdate(listenerUpdate); @@ -392,13 +799,25 @@ static EnvoyServerProtoData.Listener buildListener( "filter-chain-foo", filterChainMatch, httpConnectionManager, tlsContext, tlsContextManager); EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener.create( - name, address, ImmutableList.of(defaultFilterChain), null); + name, address, ImmutableList.of(defaultFilterChain), null, Protocol.TCP); return listener; } private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( - final UpstreamTlsContext upstreamTlsContext, String overrideAuthority) - throws URISyntaxException { + final UpstreamTlsContext upstreamTlsContext, String overrideAuthority) { + return getBlockingStub(upstreamTlsContext, overrideAuthority, overrideAuthority); + } + + // Two separate parameters for overrideAuthority and addrAttribute is for the SAN SNI validation + // test tlsClientServer_useSystemRootCerts_sni_san_validation_from_hostname that uses hostname + // passed for SNI. foo.test.google.fr is used for virtual host matching via authority but it + // can't be used for SNI in this testcase because foo.test.google.fr needs wildcard matching to + // match against *.test.google.fr in the certificate SNI, which isn't implemented yet + // (https://github.com/grpc/grpc-java/pull/12345 implements it) + // so use an exact match SAN such as waterzooi.test.google.be for SNI for this testcase. 
+ private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( + final UpstreamTlsContext upstreamTlsContext, String overrideAuthority, + String addrNameAttribute) { ManagedChannelBuilder channelBuilder = Grpc.newChannelBuilder( "sectest://localhost:" + port, @@ -410,16 +829,18 @@ private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( InetSocketAddress socketAddress = new InetSocketAddress(Inet4Address.getLoopbackAddress(), port); tlsContextManagerForClient = new TlsContextManagerImpl(bootstrapInfoForClient); - Attributes attrs = - (upstreamTlsContext != null) - ? Attributes.newBuilder() - .set(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, - new SslContextProviderSupplier( - upstreamTlsContext, tlsContextManagerForClient)) - .build() - : Attributes.EMPTY; + Attributes.Builder sslContextAttributesBuilder = (upstreamTlsContext != null) + ? Attributes.newBuilder() + .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + new SslContextProviderSupplier( + upstreamTlsContext, tlsContextManagerForClient)) + : Attributes.newBuilder(); + if (addrNameAttribute != null) { + sslContextAttributesBuilder.set(XdsInternalAttributes.ATTR_ADDRESS_NAME, addrNameAttribute); + } + sslContextAttributes = sslContextAttributesBuilder.build(); fakeNameResolverFactory.setServers( - ImmutableList.of(new EquivalentAddressGroup(socketAddress, attrs))); + ImmutableList.of(new EquivalentAddressGroup(socketAddress, sslContextAttributes))); return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build())); } @@ -445,10 +866,49 @@ public void run() { } } }); - xdsClient.ldsResource.get(8000, TimeUnit.MILLISECONDS); + try { + xdsClient.ldsResource.get(8000, TimeUnit.MILLISECONDS); + } catch (Exception ex) { + // start() probably failed, so throw its exception + if (settableFuture.isDone()) { + Throwable t = settableFuture.get(); + if (t != null) { + throw new Exception(t); + } + } + throw ex; + } return settableFuture; } + 
private void setTrustStoreSystemProperties(String trustStoreFilePath) throws Exception { + System.setProperty("javax.net.ssl.trustStore", trustStoreFilePath); + System.setProperty("javax.net.ssl.trustStorePassword", "changeit"); + System.setProperty("javax.net.ssl.trustStoreType", "JKS"); + createDefaultTrustManager(); + } + + private void clearTrustStoreSystemProperties() throws Exception { + System.clearProperty("javax.net.ssl.trustStore"); + System.clearProperty("javax.net.ssl.trustStorePassword"); + System.clearProperty("javax.net.ssl.trustStoreType"); + createDefaultTrustManager(); + } + + /** + * Workaround the JDK's TrustManagerStore race. TrustManagerStore has a cache for the default + * certs based on the system properties. But updating the cache is not thread-safe and can cause a + * half-updated cache to appear fully-updated. When both the client and server initialize their + * trust store simultaneously, one can see a half-updated value. Creating the trust manager here + * fixes the cache while no other threads are running and thus the client and server threads won't + * race to update it. See https://github.com/grpc/grpc-java/issues/11678. 
+ */ + private void createDefaultTrustManager() throws Exception { + TrustManagerFactory factory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + factory.init((KeyStore) null); + } + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { @Override @@ -520,7 +980,8 @@ public void refresh() { } void resolved() { - ResolutionResult.Builder builder = ResolutionResult.newBuilder().setAddresses(servers); + ResolutionResult.Builder builder = ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(servers)); listener.onResult(builder.build()); } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java index d28c7d7c607..ac990226259 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsServerTestHelper.buildTestListener; import static org.junit.Assert.fail; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; @@ -26,13 +27,16 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.SettableFuture; import io.grpc.BindableService; import io.grpc.InsecureServerCredentials; import io.grpc.ServerServiceDefinition; import io.grpc.Status; import io.grpc.StatusException; +import io.grpc.StatusOr; import io.grpc.testing.GrpcCleanupRule; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; @@ -40,7 +44,6 @@ import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.SocketAddress; -import 
java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -81,6 +84,7 @@ private void buildBuilder(XdsServerBuilder.XdsServingStatusListener xdsServingSt XdsServerBuilder.forPort( port, XdsServerCredentials.create(InsecureServerCredentials.create())); builder.xdsClientPoolFactory(xdsClientPoolFactory); + builder.overrideBootstrapForTest(XdsServerTestHelper.RAW_BOOTSTRAP); if (xdsServingStatusListener != null) { builder.xdsServingStatusListener(xdsServingStatusListener); } @@ -135,7 +139,18 @@ public void run() { } } }); - xdsClient.ldsResource.get(5000, TimeUnit.MILLISECONDS); + try { + xdsClient.ldsResource.get(5000, TimeUnit.MILLISECONDS); + } catch (TimeoutException ex) { + // start() probably failed, so throw its exception + if (settableFuture.isDone()) { + Throwable t = settableFuture.get(); + if (t != null) { + throw new ExecutionException(t); + } + } + throw ex; + } return settableFuture; } @@ -195,13 +210,14 @@ public void xdsServer_discoverState() throws Exception { CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); future.get(5000, TimeUnit.MILLISECONDS); - xdsClient.ldsWatcher.onError(Status.ABORTED); + xdsClient.ldsWatcher.onAmbientError(Status.ABORTED); verify(mockXdsServingStatusListener, never()).onNotServing(any(StatusException.class)); reset(mockXdsServingStatusListener); - xdsClient.ldsWatcher.onError(Status.CANCELLED); + xdsClient.ldsWatcher.onAmbientError(Status.CANCELLED); verify(mockXdsServingStatusListener, never()).onNotServing(any(StatusException.class)); reset(mockXdsServingStatusListener); - xdsClient.ldsWatcher.onResourceDoesNotExist("not found error"); + Status notFoundStatus = Status.NOT_FOUND.withDescription("not found error"); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); verify(mockXdsServingStatusListener).onNotServing(any(StatusException.class)); reset(mockXdsServingStatusListener); 
XdsServerTestHelper.generateListenerUpdate( @@ -221,10 +237,13 @@ public void xdsServer_startError() buildServer(mockXdsServingStatusListener); Future future = startServerAsync(); // create port conflict for start to fail - XdsServerTestHelper.generateListenerUpdate( - xdsClient, + EnvoyServerProtoData.Listener listener = buildTestListener( + "listener1", "0.0.0.0:" + port, ImmutableList.of(), CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), - tlsContextManager); + null, tlsContextManager); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); + xdsClient.deliverLdsUpdate(listenerUpdate); + Throwable exception = future.get(5, TimeUnit.SECONDS); assertThat(exception).isInstanceOf(IOException.class); assertThat(exception).hasMessageThat().contains("Failed to bind"); @@ -249,7 +268,7 @@ public void xdsServerStartSecondUpdateAndError() tlsContextManager); verify(mockXdsServingStatusListener, never()).onNotServing(any(Throwable.class)); verifyServer(future, mockXdsServingStatusListener, null); - xdsClient.ldsWatcher.onError(Status.ABORTED); + xdsClient.ldsWatcher.onAmbientError(Status.ABORTED); verifyServer(null, mockXdsServingStatusListener, null); } @@ -298,9 +317,12 @@ public void drainGraceTime_negativeThrows() throws IOException { @Test public void testOverrideBootstrap() throws Exception { - Map b = new HashMap<>(); + Map b = XdsServerTestHelper.RAW_BOOTSTRAP; buildBuilder(null); builder.overrideBootstrapForTest(b); - assertThat(xdsClientPoolFactory.savedBootstrap).isEqualTo(b); + xdsServer = cleanupRule.register((XdsServerWrapper) builder.build()); + Future unused = startServerAsync(); + assertThat(xdsClientPoolFactory.savedBootstrapInfo.node().getId()) + .isEqualTo(XdsServerTestHelper.BOOTSTRAP_INFO.node().getId()); } } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index 791318c5355..386793299d8 100644 --- 
a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -21,7 +21,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.SettableFuture; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.grpc.InsecureChannelCredentials; +import io.grpc.MetricRecorder; +import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.internal.ObjectPool; import io.grpc.xds.EnvoyServerProtoData.ConnectionSourceType; import io.grpc.xds.EnvoyServerProtoData.FilterChain; @@ -35,8 +39,8 @@ import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.EnvoyProtoData; import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.client.XdsResourceType; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -44,7 +48,10 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; /** @@ -57,6 +64,17 @@ public class XdsServerTestHelper { "projects/42/networks/default/nodes/5c85b298-6f5b-4722-b74a-f7d1f0ccf5ad"; private static final EnvoyProtoData.Node BOOTSTRAP_NODE = EnvoyProtoData.Node.newBuilder().setId(NODE_ID).build(); + static final Map RAW_BOOTSTRAP = ImmutableMap.of( + "node", ImmutableMap.of( + "id", NODE_ID), + "server_listener_resource_name_template", "grpc/server?udpa.resource.listening_address=%s", + "xds_servers", ImmutableList.of( + ImmutableMap.of( + "server_uri", SERVER_URI, + "channel_creds", ImmutableList.of( + ImmutableMap.of( + "type", "insecure"))) + )); static final Bootstrapper.BootstrapInfo BOOTSTRAP_INFO = 
Bootstrapper.BootstrapInfo.builder() .servers(Arrays.asList( @@ -69,7 +87,7 @@ public class XdsServerTestHelper { static void generateListenerUpdate(FakeXdsClient xdsClient, EnvoyServerProtoData.DownstreamTlsContext tlsContext, TlsContextManager tlsContextManager) { - EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", + EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "0.0.0.0:0", ImmutableList.of(), tlsContext, null, tlsContextManager); LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); xdsClient.deliverLdsUpdate(listenerUpdate); @@ -80,7 +98,8 @@ static void generateListenerUpdate( EnvoyServerProtoData.DownstreamTlsContext tlsContext, EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, TlsContextManager tlsContextManager) { - EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", sourcePorts, + EnvoyServerProtoData.Listener listener = buildTestListener( + "listener1", "0.0.0.0:7000", sourcePorts, tlsContext, tlsContextForDefaultFilterChain, tlsContextManager); LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); xdsClient.deliverLdsUpdate(listenerUpdate); @@ -125,7 +144,7 @@ static EnvoyServerProtoData.Listener buildTestListener( tlsContextForDefaultFilterChain, tlsContextManager); EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener.create( - name, address, ImmutableList.of(filterChain1), defaultFilterChain); + name, address, ImmutableList.of(filterChain1), defaultFilterChain, Protocol.TCP); return listener; } @@ -133,17 +152,12 @@ static final class FakeXdsClientPoolFactory implements XdsClientPoolFactory { private XdsClient xdsClient; - Map savedBootstrap; + BootstrapInfo savedBootstrapInfo; FakeXdsClientPoolFactory(XdsClient xdsClient) { this.xdsClient = xdsClient; } - @Override - public void setBootstrapOverride(Map bootstrap) { - this.savedBootstrap = bootstrap; - } - @Override @Nullable public ObjectPool 
get(String target) { @@ -151,7 +165,9 @@ public ObjectPool get(String target) { } @Override - public ObjectPool getOrCreate(String target) throws XdsInitializationException { + public ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder) { + this.savedBootstrapInfo = bootstrapInfo; return new ObjectPool() { @Override public XdsClient getObject() { @@ -172,12 +188,18 @@ public List getTargets() { } } + // Implementation details: + // 1. Use `synchronized` in methods where XdsClientImpl uses its own `syncContext`. + // 2. Use `serverExecutor` via `execute()` in methods where XdsClientImpl uses watcher's executor. static final class FakeXdsClient extends XdsClient { - boolean shutdown; - SettableFuture ldsResource = SettableFuture.create(); - ResourceWatcher ldsWatcher; - CountDownLatch rdsCount = new CountDownLatch(1); + public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5); + + private boolean shutdown; + @Nullable SettableFuture ldsResource = SettableFuture.create(); + @Nullable ResourceWatcher ldsWatcher; + private CountDownLatch rdsCount = new CountDownLatch(1); final Map> rdsWatchers = new HashMap<>(); + @Nullable private volatile Executor serverExecutor; @Override public TlsContextManager getSecurityConfig() { @@ -191,14 +213,20 @@ public BootstrapInfo getBootstrapInfo() { @Override @SuppressWarnings("unchecked") - public void watchXdsResource(XdsResourceType resourceType, - String resourceName, - ResourceWatcher watcher, - Executor syncContext) { + public synchronized void watchXdsResource( + XdsResourceType resourceType, + String resourceName, + ResourceWatcher watcher, + Executor executor) { + if (serverExecutor != null) { + assertThat(executor).isEqualTo(serverExecutor); + } + switch (resourceType.typeName()) { case "LDS": assertThat(ldsWatcher).isNull(); ldsWatcher = (ResourceWatcher) watcher; + serverExecutor = executor; ldsResource.set(resourceName); break; case "RDS": @@ -211,14 +239,14 @@ 
public void watchXdsResource(XdsResourceType resou } @Override - public void cancelXdsResourceWatch(XdsResourceType type, - String resourceName, - ResourceWatcher watcher) { + public synchronized void cancelXdsResourceWatch( + XdsResourceType type, String resourceName, ResourceWatcher watcher) { switch (type.typeName()) { case "LDS": assertThat(ldsWatcher).isNotNull(); ldsResource = null; ldsWatcher = null; + serverExecutor = null; break; case "RDS": rdsWatchers.remove(resourceName); @@ -228,27 +256,92 @@ public void cancelXdsResourceWatch(XdsResourceType } @Override - public void shutdown() { + public synchronized void shutdown() { shutdown = true; } @Override - public boolean isShutDown() { + public synchronized boolean isShutDown() { return shutdown; } - void deliverLdsUpdate(List filterChains, - FilterChain defaultFilterChain) { - ldsWatcher.onChanged(LdsUpdate.forTcpListener(Listener.create( - "listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain))); + public void awaitRds(Duration timeout) throws InterruptedException, TimeoutException { + if (!rdsCount.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) { + throw new TimeoutException("Timeout " + timeout + " waiting for RDSs"); + } + } + + public void setExpectedRdsCount(int count) { + rdsCount = new CountDownLatch(count); + } + + private void execute(Runnable action) { + // This method ensures that all watcher updates: + // - Happen after the server started watching LDS. + // - Are executed within the sync context of the server. + // + // Note that this doesn't guarantee that any of the RDS watchers are created. + // Tests should use setExpectedRdsCount(int) and awaitRds() for that. 
+ awaitLdsResource(DEFAULT_TIMEOUT); + serverExecutor.execute(action); + } + + private String awaitLdsResource(Duration timeout) { + if (ldsResource == null) { + throw new IllegalStateException("xDS resource update after watcher cancel"); + } + try { + return ldsResource.get(timeout.toMillis(), TimeUnit.MILLISECONDS); + } catch (ExecutionException | TimeoutException e) { + throw new RuntimeException("Can't resolve LDS resource name in " + timeout, e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + void deliverLdsUpdateWithApiListener(long httpMaxStreamDurationNano, + List virtualHosts) { + execute(() -> { + LdsUpdate update = LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( + httpMaxStreamDurationNano, virtualHosts, null)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(update)); + }); } void deliverLdsUpdate(LdsUpdate ldsUpdate) { - ldsWatcher.onChanged(ldsUpdate); + execute(() -> ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate))); + } + + void deliverLdsUpdate( + List filterChains, + @Nullable FilterChain defaultFilterChain) { + deliverLdsUpdate(LdsUpdate.forTcpListener(Listener.create("listener", "0.0.0.0:1", + ImmutableList.copyOf(filterChains), defaultFilterChain, Protocol.TCP))); + } + + void deliverLdsUpdate(FilterChain filterChain, @Nullable FilterChain defaultFilterChain) { + deliverLdsUpdate(ImmutableList.of(filterChain), defaultFilterChain); + } + + void deliverLdsResourceNotFound() { + String resourceName = awaitLdsResource(DEFAULT_TIMEOUT); + Status status = Status.NOT_FOUND.withDescription("Resource not found: " + resourceName); + execute(() -> ldsWatcher.onResourceChanged(StatusOr.fromStatus(status))); + } + + void deliverRdsUpdate(String resourceName, List virtualHosts) { + RdsUpdate update = new RdsUpdate(virtualHosts); + execute(() -> rdsWatchers.get(resourceName).onResourceChanged(StatusOr.fromValue(update))); + } + + void deliverRdsUpdate(String 
resourceName, VirtualHost virtualHost) { + deliverRdsUpdate(resourceName, ImmutableList.of(virtualHost)); } - void deliverRdsUpdate(String rdsName, List virtualHosts) { - rdsWatchers.get(rdsName).onChanged(new RdsUpdate(virtualHosts)); + void deliverRdsResourceNotFound(String resourceName) { + Status status = Status.NOT_FOUND.withDescription("Resource not found: " + resourceName); + execute(() -> rdsWatchers.get(resourceName).onResourceChanged(StatusOr.fromStatus(status))); } } } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index 55b8812cd17..99e3911307a 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -18,6 +18,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; import static io.grpc.xds.XdsServerWrapper.RETRY_DELAY_NANOS; import static org.junit.Assert.fail; @@ -31,11 +32,12 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.withSettings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.net.InetAddresses; import com.google.common.util.concurrent.SettableFuture; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.grpc.Attributes; import io.grpc.InsecureChannelCredentials; import io.grpc.Metadata; @@ -47,17 +49,22 @@ import io.grpc.ServerInterceptor; import io.grpc.Status; import io.grpc.StatusException; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.EnvoyServerProtoData.CidrRange; import 
io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; +import io.grpc.xds.EnvoyServerProtoData.Listener; import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; -import io.grpc.xds.Filter.ServerInterceptorBuilder; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.StatefulFilter.Config; import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; @@ -72,11 +79,12 @@ import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import io.grpc.xds.internal.security.SslContextProviderSupplier; import java.io.IOException; +import java.net.InetAddress; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.concurrent.CountDownLatch; +import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -96,6 +104,14 @@ @RunWith(JUnit4.class) public class XdsServerWrapperTest { private static final int START_WAIT_AFTER_LISTENER_MILLIS = 100; + private static final String ROUTER_FILTER_INSTANCE_NAME = "envoy.router"; + private static final RouterFilter.Provider ROUTER_FILTER_PROVIDER = new RouterFilter.Provider(); + + // Readability: makes it simpler to distinguish resource parameters. 
+ private static final ImmutableMap NO_FILTER_OVERRIDES = ImmutableMap.of(); + + private static final String STATEFUL_1 = "stateful_1"; + private static final String STATEFUL_2 = "stateful_2"; @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @@ -120,6 +136,7 @@ public void setup() { when(mockBuilder.build()).thenReturn(mockServer); xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, filterRegistry, executor.getScheduledExecutorService()); } @@ -142,7 +159,8 @@ public void testBootstrap() throws Exception { XdsListenerResource listenerResource = XdsListenerResource.getInstance(); when(xdsClient.getBootstrapInfo()).thenReturn(b); xdsServerWrapper = new XdsServerWrapper("[::FFFF:129.144.52.38]:80", mockBuilder, listener, - selectorManager, new FakeXdsClientPoolFactory(xdsClient), filterRegistry); + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, filterRegistry); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override public void run() { @@ -175,7 +193,8 @@ private void verifyBootstrapFail(Bootstrapper.BootstrapInfo b) throws Exception XdsClient xdsClient = mock(XdsClient.class); when(xdsClient.getBootstrapInfo()).thenReturn(b); xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, - selectorManager, new FakeXdsClientPoolFactory(xdsClient), filterRegistry); + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, filterRegistry); final SettableFuture start = SettableFuture.create(); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override @@ -214,7 +233,8 @@ public void testBootstrap_templateWithXdstp() throws Exception { XdsListenerResource listenerResource = XdsListenerResource.getInstance(); when(xdsClient.getBootstrapInfo()).thenReturn(b); xdsServerWrapper = new 
XdsServerWrapper("[::FFFF:129.144.52.38]:80", mockBuilder, listener, - selectorManager, new FakeXdsClientPoolFactory(xdsClient), filterRegistry); + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, filterRegistry); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override public void run() { @@ -254,7 +274,7 @@ public void run() { FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); FilterChain f1 = createFilterChain("filter-chain-1", createRds("rds")); xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); verify(listener, timeout(5000)).onServing(); @@ -263,7 +283,7 @@ public void run() { xdsServerWrapper.shutdown(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsClient.ldsResource).isNull(); - assertThat(xdsClient.shutdown).isTrue(); + assertThat(xdsClient.isShutDown()).isTrue(); verify(mockServer).shutdown(); assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue(); assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue(); @@ -305,7 +325,7 @@ public void run() { verify(mockServer, never()).start(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsClient.ldsResource).isNull(); - assertThat(xdsClient.shutdown).isTrue(); + assertThat(xdsClient.isShutDown()).isTrue(); verify(mockServer).shutdown(); assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue(); assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue(); @@ -326,7 +346,8 @@ public void run() { } }); String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + Status notFoundStatus = Status.NOT_FOUND.withDescription("Resource not found: " + ldsResource); + 
xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); verify(listener, timeout(5000)).onNotServing(any()); try { start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); @@ -344,7 +365,7 @@ public void run() { xdsServerWrapper.shutdown(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsClient.ldsResource).isNull(); - assertThat(xdsClient.shutdown).isTrue(); + assertThat(xdsClient.isShutDown()).isTrue(); verify(mockBuilder, times(1)).build(); verify(mockServer, times(1)).shutdown(); xdsServerWrapper.awaitTermination(1, TimeUnit.SECONDS); @@ -369,7 +390,7 @@ public void run() { FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); SslContextProviderSupplier sslSupplier = filterChain.sslContextProviderSupplier(); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); try { @@ -436,7 +457,7 @@ public void run() { xdsClient.ldsResource.get(5, TimeUnit.SECONDS); FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); try { @@ -515,7 +536,8 @@ public void run() { verify(mockServer).start(); // server shutdown after resourceDoesNotExist - xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + Status notFoundStatus = Status.NOT_FOUND.withDescription("Resource not found: " + ldsResource); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); verify(mockServer).shutdown(); // re-deliver lds resource @@ -526,6 +548,150 @@ public void run() { 
verify(mockServer).start(); } + @Test + public void onChanged_listenerIsNull() + throws ExecutionException, InterruptedException, TimeoutException { + xdsServerWrapper = new XdsServerWrapper("10.1.2.3:1", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, + filterRegistry, executor.getScheduledExecutorService()); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsResource).isEqualTo("grpc/server?udpa.resource.listening_address=10.1.2.3:1"); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + + xdsClient.deliverLdsUpdateWithApiListener(0L, Arrays.asList(virtualHost)); + + verify(listener, timeout(10000)).onNotServing(any()); + } + + @Test + public void onChanged_listenerAddressMissingPort() + throws ExecutionException, InterruptedException, TimeoutException { + xdsServerWrapper = new XdsServerWrapper("10.1.2.3:1", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, + filterRegistry, executor.getScheduledExecutorService()); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsResource).isEqualTo("grpc/server?udpa.resource.listening_address=10.1.2.3:1"); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", 
Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain filterChain = EnvoyServerProtoData.FilterChain.create( + "filter-chain-foo", createMatch(), httpConnectionManager, createTls(), + mock(TlsContextManager.class)); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener( + Listener.create("listener", "20.3.4.5:", + ImmutableList.copyOf(Collections.singletonList(filterChain)), null, Protocol.TCP)); + + xdsClient.deliverLdsUpdate(listenerUpdate); + + verify(listener, timeout(10000)).onNotServing(any()); + } + + @Test + public void onChanged_listenerAddressMismatch() + throws ExecutionException, InterruptedException, TimeoutException { + xdsServerWrapper = new XdsServerWrapper("10.1.2.3:1", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, + filterRegistry, executor.getScheduledExecutorService()); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsResource).isEqualTo("grpc/server?udpa.resource.listening_address=10.1.2.3:1"); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain filterChain = EnvoyServerProtoData.FilterChain.create( + "filter-chain-foo", createMatch(), httpConnectionManager, createTls(), + 
mock(TlsContextManager.class)); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener( + Listener.create("listener", "20.3.4.5:1", + ImmutableList.copyOf(Collections.singletonList(filterChain)), null, Protocol.TCP)); + + xdsClient.deliverLdsUpdate(listenerUpdate); + + verify(listener, timeout(10000)).onNotServing(any()); + } + + @Test + public void onChanged_listenerAddressPortMismatch() + throws ExecutionException, InterruptedException, TimeoutException { + xdsServerWrapper = new XdsServerWrapper("10.1.2.3:1", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, + filterRegistry, executor.getScheduledExecutorService()); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsResource).isEqualTo("grpc/server?udpa.resource.listening_address=10.1.2.3:1"); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain filterChain = EnvoyServerProtoData.FilterChain.create( + "filter-chain-foo", createMatch(), httpConnectionManager, createTls(), + mock(TlsContextManager.class)); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener( + Listener.create("listener", "10.1.2.3:2", + ImmutableList.copyOf(Collections.singletonList(filterChain)), null, Protocol.TCP)); + + xdsClient.deliverLdsUpdate(listenerUpdate); + + verify(listener, timeout(10000)).onNotServing(any()); + } + @Test public void discoverState_rds() throws Exception { final SettableFuture start 
= SettableFuture.create(); @@ -546,7 +712,7 @@ public void run() { 0L, Collections.singletonList(virtualHost), new ArrayList()); EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); - xdsClient.rdsCount = new CountDownLatch(3); + xdsClient.setExpectedRdsCount(3); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); assertThat(start.isDone()).isFalse(); assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); @@ -558,7 +724,7 @@ public void run() { xdsClient.deliverLdsUpdate(Arrays.asList(f0, f2), f3); verify(mockServer, never()).start(); verify(listener, never()).onServing(); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("r1", Collections.singletonList(createVirtualHost("virtual-host-1"))); @@ -604,12 +770,11 @@ public void run() { EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r0")); - xdsClient.rdsCount = new CountDownLatch(1); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), f2); assertThat(start.isDone()).isFalse(); assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("r0", Collections.singletonList(createVirtualHost("virtual-host-0"))); start.get(5000, TimeUnit.MILLISECONDS); @@ -635,9 +800,9 @@ public void run() { EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0")); EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1")); EnvoyServerProtoData.FilterChain f5 = createFilterChain("filter-chain-4", createRds("r1")); - xdsClient.rdsCount = new CountDownLatch(1); + xdsClient.setExpectedRdsCount(1); 
xdsClient.deliverLdsUpdate(Arrays.asList(f5, f3), f4); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("r1", Collections.singletonList(createVirtualHost("virtual-host-1"))); xdsClient.deliverRdsUpdate("r0", @@ -690,8 +855,8 @@ public void run() { EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); - xdsClient.rdsCount.await(); - xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + xdsClient.rdsWatchers.get("r0").onResourceChanged(StatusOr.fromStatus(Status.CANCELLED)); start.get(5000, TimeUnit.MILLISECONDS); assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) .isEqualTo(2); @@ -711,13 +876,14 @@ public void run() { Collections.singletonList(createVirtualHost("virtual-host-1"))); assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); - xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); + xdsClient.rdsWatchers.get("r0").onAmbientError(Status.CANCELLED); realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); - xdsClient.rdsWatchers.get("r0").onResourceDoesNotExist("r0"); + Status notFoundStatus = Status.NOT_FOUND.withDescription("Resource r0 does not exist"); + xdsClient.rdsWatchers.get("r0").onResourceChanged(StatusOr.fromStatus(notFoundStatus)); realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); assertThat(realConfig.virtualHosts()).isEmpty(); assertThat(realConfig.interceptors()).isEmpty(); @@ -737,7 +903,9 @@ public void run() { } 
}); String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + Status notFoundStatus = Status.NOT_FOUND.withDescription( + "FakeXdsClient: Resource not found: " + ldsResource); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); verify(listener, timeout(5000)).onNotServing(any()); try { start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); @@ -751,10 +919,10 @@ public void run() { FilterChain filterChain0 = createFilterChain("filter-chain-0", createRds("rds")); SslContextProviderSupplier sslSupplier0 = filterChain0.sslContextProviderSupplier(); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain0), null); - xdsClient.ldsWatcher.onError(Status.INTERNAL); + ResourceWatcher saveRdsWatcher = xdsClient.rdsWatchers.get("rds"); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(Status.INTERNAL)); assertThat(selectorManager.getSelectorToUpdateSelector()) .isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); - ResourceWatcher saveRdsWatcher = xdsClient.rdsWatchers.get("rds"); verify(mockBuilder, times(1)).build(); verify(listener, times(2)).onNotServing(any(StatusException.class)); assertThat(sslSupplier0.isShutdown()).isFalse(); @@ -790,7 +958,7 @@ public void run() { xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-2"))); assertThat(sslSupplier1.isShutdown()).isFalse(); - xdsClient.ldsWatcher.onError(Status.DEADLINE_EXCEEDED); + xdsClient.ldsWatcher.onAmbientError(Status.DEADLINE_EXCEEDED); verify(mockBuilder, times(1)).build(); verify(mockServer, times(2)).start(); verify(listener, times(2)).onNotServing(any(StatusException.class)); @@ -805,17 +973,18 @@ public void run() { assertThat(sslSupplier1.isShutdown()).isFalse(); // not serving after serving - xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); 
assertThat(xdsClient.rdsWatchers).isEmpty(); - verify(mockServer, times(2)).shutdown(); + verify(mockServer, times(3)).shutdown(); // This is the 3rd shutdown in the test. when(mockServer.isShutdown()).thenReturn(true); assertThat(selectorManager.getSelectorToUpdateSelector()) .isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); verify(listener, times(3)).onNotServing(any(StatusException.class)); assertThat(sslSupplier1.isShutdown()).isTrue(); + assertThat(xdsClient.rdsWatchers.get("rds")).isNull(); // no op - saveRdsWatcher.onChanged( - new RdsUpdate(Collections.singletonList(createVirtualHost("virtual-host-1")))); + saveRdsWatcher.onResourceChanged(StatusOr.fromValue( + new RdsUpdate(Collections.singletonList(createVirtualHost("virtual-host-1"))))); verify(mockBuilder, times(1)).build(); verify(mockServer, times(2)).start(); verify(listener, times(1)).onServing(); @@ -844,8 +1013,8 @@ public void run() { assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); assertThat(executor.numPendingTasks()).isEqualTo(1); - xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); - verify(mockServer, times(3)).shutdown(); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); + verify(mockServer, times(4)).shutdown(); verify(listener, times(4)).onNotServing(any(StatusException.class)); verify(listener, times(1)).onNotServing(any(IOException.class)); when(mockServer.isShutdown()).thenReturn(true); @@ -873,7 +1042,7 @@ public void run() { assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); xdsServerWrapper.shutdown(); - verify(mockServer, times(4)).shutdown(); + verify(mockServer, times(5)).shutdown(); assertThat(sslSupplier3.isShutdown()).isTrue(); when(mockServer.awaitTermination(anyLong(), any(TimeUnit.class))).thenReturn(true); assertThat(xdsServerWrapper.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); @@ -957,9 +1126,11 @@ public void run() { new AtomicReference<>(routingConfig)).build()); 
when(serverCall.getAuthority()).thenReturn("not-match.google.com"); - Filter filter = mock(Filter.class); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + filterRegistry.register(filterProvider); + ServerCallHandler next = mock(ServerCallHandler.class); interceptor.interceptCall(serverCall, new Metadata(), next); verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); @@ -998,9 +1169,11 @@ public void run() { when(serverCall.getMethodDescriptor()).thenReturn(createMethod("NotMatchMethod")); when(serverCall.getAuthority()).thenReturn("foo.google.com"); - Filter filter = mock(Filter.class); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + filterRegistry.register(filterProvider); + ServerCallHandler next = mock(ServerCallHandler.class); interceptor.interceptCall(serverCall, new Metadata(), next); verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); @@ -1035,7 +1208,8 @@ public void run() { "/FooService/barMethod", "foo.google.com", Route.RouteAction.forCluster( - "cluster", Collections.emptyList(), null, null)); + "cluster", Collections.emptyList(), null, null, + false)); ServerCall serverCall = mock(ServerCall.class); when(serverCall.getAttributes()).thenReturn( Attributes.newBuilder() @@ -1043,9 +1217,11 @@ public void run() { when(serverCall.getMethodDescriptor()).thenReturn(createMethod("FooService/barMethod")); when(serverCall.getAuthority()).thenReturn("foo.google.com"); - Filter filter = mock(Filter.class); 
- when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + filterRegistry.register(filterProvider); + ServerCallHandler next = mock(ServerCallHandler.class); interceptor.interceptCall(serverCall, new Metadata(), next); verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); @@ -1112,10 +1288,14 @@ public void run() { RouteMatch.create( PathMatcher.fromPath("/FooService/barMethod", true), Collections.emptyList(), null); - Filter filter = mock(Filter.class, withSettings() - .extraInterfaces(ServerInterceptorBuilder.class)); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + + Filter filter = mock(Filter.class); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + when(filterProvider.newInstance(any(String.class))).thenReturn(filter); + filterRegistry.register(filterProvider); + FilterConfig f0 = mock(FilterConfig.class); FilterConfig f0Override = mock(FilterConfig.class); when(f0.typeUrl()).thenReturn("filter-type-url"); @@ -1136,10 +1316,8 @@ public ServerCall.Listener interceptCall(ServerCallof()); VirtualHost virtualHost = VirtualHost.create( @@ -1184,10 +1362,13 @@ public void run() { }); xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - Filter filter = mock(Filter.class, withSettings() - .extraInterfaces(ServerInterceptorBuilder.class)); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + Filter filter = mock(Filter.class); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new 
String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + when(filterProvider.newInstance(any(String.class))).thenReturn(filter); + filterRegistry.register(filterProvider); + FilterConfig f0 = mock(FilterConfig.class); FilterConfig f0Override = mock(FilterConfig.class); when(f0.typeUrl()).thenReturn("filter-type-url"); @@ -1208,10 +1389,8 @@ public ServerCall.Listener interceptCall(ServerCall ServerCall.Listener interceptCall(ServerCall ServerCall.Listener interceptCall(ServerCall serverStart = filterStateTestStartServer(filterRegistry); + + VirtualHost vhost = filterStateTestVhost(); + + // LDS 1. + FilterChain lds1FilterChain = createFilterChain("chain_0", + createHcm(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2))); + xdsClient.deliverLdsUpdate(lds1FilterChain, null); + verifyServerStarted(serverStart); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + // Verify that StatefulFilter with different filter names result in different Filter instances. + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + assertThat(lds1Filter1).isNotSameInstanceAs(lds1Filter2); + // Redundant check just in case StatefulFilter synchronization is broken. + assertThat(lds1Filter1.idx).isEqualTo(0); + assertThat(lds1Filter2.idx).isEqualTo(1); + + // LDS 2: filter configs with the same names. + FilterChain lds2FilterChain = createFilterChain("chain_0", + createHcm(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2))); + xdsClient.deliverLdsUpdate(lds2FilterChain, null); + ImmutableList lds2Snapshot = statefulFilterProvider.getAllInstances(); + // Filter names hasn't changed, so expecting no new StatefulFilter instances. 
+ assertWithMessage("LDS 2: Expected Filter instances to be reused across LDS updates") + .that(lds2Snapshot).isEqualTo(lds1Snapshot); + + // LDS 3: Filter "STATEFUL_2" removed. + FilterChain lds3FilterChain = createFilterChain("chain_0", + createHcm(vhost, filterStateTestConfigs(STATEFUL_1))); + xdsClient.deliverLdsUpdate(lds3FilterChain, null); + ImmutableList lds3Snapshot = statefulFilterProvider.getAllInstances(); + // Again, no new StatefulFilter instances should be created. + assertWithMessage("LDS 3: Expected Filter instances to be reused across LDS updates") + .that(lds3Snapshot).isEqualTo(lds1Snapshot); + // Verify the shutdown state. + assertThat(lds1Filter1.isShutdown()).isFalse(); + assertWithMessage("LDS 3: Expected %s to be shut down", lds1Filter2) + .that(lds1Filter2.isShutdown()).isTrue(); + + // LDS 4: Filter "STATEFUL_2" added back. + FilterChain lds4FilterChain = createFilterChain("chain_0", + createHcm(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2))); + xdsClient.deliverLdsUpdate(lds4FilterChain, null); + ImmutableList lds4Snapshot = statefulFilterProvider.getAllInstances(); + // Filter "STATEFUL_2" should be treated as any other new filter name in an LDS update: + // a new instance should be created. + assertWithMessage("LDS 4: Expected a new filter instance for %s", STATEFUL_2) + .that(lds4Snapshot).hasSize(3); + StatefulFilter lds4Filter2 = lds4Snapshot.get(2); + assertThat(lds4Filter2.idx).isEqualTo(2); + assertThat(lds4Filter2).isNotSameInstanceAs(lds1Filter2); + assertThat(lds4Snapshot).containsAtLeastElementsIn(lds1Snapshot); + // Verify the shutdown state. 
+ assertThat(lds1Filter1.isShutdown()).isFalse(); + assertThat(lds1Filter2.isShutdown()).isTrue(); + assertThat(lds4Filter2.isShutdown()).isFalse(); + } + + @Test + public void filterState_survivesRds() throws Exception { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = filterStateTestFilterRegistry(statefulFilterProvider); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + String rdsName = "rds.example.com"; + + // LDS 1. + FilterChain fc1 = createFilterChain("fc1", + createHcmForRds(rdsName, filterStateTestConfigs(STATEFUL_1, STATEFUL_2))); + xdsClient.deliverLdsUpdate(fc1, null); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + verify(listener, never()).onServing(); + // Server didn't start, but filter instances should have already been created. + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + assertThat(lds1Filter1).isNotSameInstanceAs(lds1Filter2); + + // RDS 1. + VirtualHost vhost1 = filterStateTestVhost(); + xdsClient.deliverRdsUpdate(rdsName, vhost1); + verifyServerStarted(serverStart); + assertThat(getSelectorRoutingConfigs()).hasSize(1); + assertThat(getSelectorVhosts(fc1)).containsExactly(vhost1); + // Initial RDS update should not generate Filter instances. + ImmutableList rds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("RDS 1: Expected Filter instances to be reused across RDS route updates") + .that(rds1Snapshot).isEqualTo(lds1Snapshot); + + // RDS 2: exactly the same as RDS 1. 
+ xdsClient.deliverRdsUpdate(rdsName, vhost1); + assertThat(getSelectorRoutingConfigs()).hasSize(1); + assertThat(getSelectorVhosts(fc1)).containsExactly(vhost1); + ImmutableList rds2Snapshot = statefulFilterProvider.getAllInstances(); + // Neither should any subsequent RDS updates. + assertWithMessage("RDS 2: Expected Filter instances to be reused across RDS route updates") + .that(rds2Snapshot).isEqualTo(lds1Snapshot); + + // RDS 3: Contains a per-route override for STATEFUL_1. + VirtualHost vhost3 = filterStateTestVhost(vhost1.name(), ImmutableMap.of( + STATEFUL_1, new Config("RDS3") + )); + xdsClient.deliverRdsUpdate(rdsName, vhost3); + assertThat(getSelectorRoutingConfigs()).hasSize(1); + assertThat(getSelectorVhosts(fc1)).containsExactly(vhost3); + ImmutableList rds3Snapshot = statefulFilterProvider.getAllInstances(); + // As with any other Route update, typed_per_filter_config overrides should not result in + // creating new filter instances. + assertWithMessage("RDS 3: Expected Filter instances to be reused on per-route filter overrides") + .that(rds3Snapshot).isEqualTo(lds1Snapshot); + } + + @Test + public void filterState_uniquePerFilterChain() { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = filterStateTestFilterRegistry(statefulFilterProvider); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + // Prepare multiple filter chains matchers for testing. + FilterChainMatch matcherA = createMatchSrcIp("3fff:a::/32"); + FilterChainMatch matcherB = createMatchSrcIp("3fff:b::/32"); + + // Vhosts won't change too. + VirtualHost vhostA = filterStateTestVhost("stateful_vhost_a"); + VirtualHost vhostB = filterStateTestVhost("stateful_vhost_b"); + + // LDS 1. 
+ FilterChain lds1ChainA = createFilterChain("chain_a", + createHcm(vhostA, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)), + matcherA); + FilterChain lds1ChainB = createFilterChain("chain_b", + createHcm(vhostB, filterStateTestConfigs(STATEFUL_2)), + matcherB); + + xdsClient.deliverLdsUpdate(ImmutableList.of(lds1ChainA, lds1ChainB), null); + verifyServerStarted(serverStart); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + // Verify that filter with name STATEFUL_2 produced separate instances unique per filter chain. + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(3); + // Naming: ldsChainFilter + StatefulFilter lds1ChainAFilter1 = lds1Snapshot.get(0); + StatefulFilter lds1ChainAFilter2 = lds1Snapshot.get(1); + StatefulFilter lds1ChainBFilter2 = lds1Snapshot.get(2); + assertThat(lds1ChainAFilter2).isNotSameInstanceAs(lds1ChainBFilter2); + + // LDS 2: In chain B filter with name STATEFUL_1 is replaced STATEFUL_2. + FilterChain lds2ChainA = createFilterChain("chain_a", + createHcm(vhostA, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)), + matcherA); + FilterChain lds2ChainB = createFilterChain("chain_b", + createHcm(vhostB, filterStateTestConfigs(STATEFUL_1)), + matcherB); + + xdsClient.deliverLdsUpdate(ImmutableList.of(lds2ChainA, lds2ChainB), null); + ImmutableList lds2Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 2: expected a distinct instance of filter %s for Chain B", STATEFUL_1) + .that(lds2Snapshot).hasSize(4); + StatefulFilter lds2ChainBFilter1 = lds2Snapshot.get(3); + assertThat(lds2ChainBFilter1).isNotSameInstanceAs(lds1ChainAFilter1); + // Confirm correct STATEFUL_2 has been shut down. 
+ assertThat(lds1ChainBFilter2.isShutdown()).isTrue(); + assertThat(lds1ChainAFilter2.isShutdown()).isFalse(); + + // LDS 3: Add default chain + // Default filter chain is an exception from the uniqueness rule, and we need to make sure + // that this is accounted for when we're tracking active filters per unique FilterChain. + FilterChain lds3ChainDefault = createFilterChain("chain_default", + createHcm(vhostA, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)), + matcherA); + xdsClient.deliverLdsUpdate(ImmutableList.of(lds2ChainA, lds2ChainB), lds3ChainDefault); + ImmutableList lds3Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 3: Expected two new distinct filter instances for default chain") + .that(lds3Snapshot).hasSize(6); + StatefulFilter lds3ChainDefaultFilter1 = lds3Snapshot.get(4); + StatefulFilter lds3ChainDefaultFilter2 = lds3Snapshot.get(5); + // STATEFUL_1 in default chain not the same STATEFUL_1 in chain A or B + assertThat(lds3ChainDefaultFilter1).isNotSameInstanceAs(lds1ChainAFilter1); + assertThat(lds3ChainDefaultFilter1).isNotSameInstanceAs(lds2ChainBFilter1); + // STATEFUL_2 in default chain not the same STATEFUL_1 in chain A + assertThat(lds3ChainDefaultFilter2).isNotSameInstanceAs(lds1ChainAFilter2); + } + + /** + * Verifies a special case where an existing filter is has a different typeUrl in a subsequent + * LDS update. + * + *

Expectations: + * 1. The old filter instance must be shutdown. + * 2. A new filter instance must be created for the new filter with different typeUrl. + */ + @Test + public void filterState_specialCase_sameNameDifferentTypeUrl() { + // Setup the server with filter containing StatefulFilter.Provider for two distict type URLs. + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + String altTypeUrl = "type.googleapis.com/grpc.test.AltStatefulFilter"; + StatefulFilter.Provider altStatefulFilterProvider = new StatefulFilter.Provider(altTypeUrl); + FilterRegistry filterRegistry = FilterRegistry.newRegistry() + .register(statefulFilterProvider, altStatefulFilterProvider, ROUTER_FILTER_PROVIDER); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + // Test a normal chain and the default chain, as it's handled separately. + VirtualHost vhost = filterStateTestVhost(); + + // LDS 1. + ImmutableList lds1Confgs = filterStateTestConfigs(STATEFUL_1, STATEFUL_2); + FilterChain lds1ChainA = createFilterChain("chain_a", createHcm(vhost, lds1Confgs)); + FilterChain lds1ChainDefault = createFilterChain("chain_default", createHcm(vhost, lds1Confgs)); + xdsClient.deliverLdsUpdate(lds1ChainA, lds1ChainDefault); + verifyServerStarted(serverStart); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(4); + // Naming: ldsChainFilter + StatefulFilter lds1ChainAFilter1 = lds1Snapshot.get(0); + StatefulFilter lds1ChainAFilter2 = lds1Snapshot.get(1); + StatefulFilter lds1ChainDefaultFilter1 = lds1Snapshot.get(2); + StatefulFilter lds1ChainDefaultFilter2 = lds1Snapshot.get(3); + + // LDS 2: Filter STATEFUL_2 present, but with a different typeUrl: altTypeUrl. 
+ ImmutableList lds2Confgs = ImmutableList.of( + new NamedFilterConfig(STATEFUL_1, new StatefulFilter.Config(STATEFUL_1)), + new NamedFilterConfig(STATEFUL_2, new StatefulFilter.Config(STATEFUL_2, altTypeUrl)), + new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG) + ); + FilterChain lds2ChainA = createFilterChain("chain_a", createHcm(vhost, lds2Confgs)); + FilterChain lds2ChainDefault = createFilterChain("chain_default", createHcm(vhost, lds2Confgs)); + xdsClient.deliverLdsUpdate(lds2ChainA, lds2ChainDefault); + ImmutableList lds2Snapshot = statefulFilterProvider.getAllInstances(); + ImmutableList lds2SnapshotAlt = altStatefulFilterProvider.getAllInstances(); + // Filter "STATEFUL_2" has different typeUrl, and should be treated as a new filter. + // No changes in the snapshot of normal stateful filters. + assertThat(lds2Snapshot).isEqualTo(lds1Snapshot); + // Two new filter instances is created by altStatefulFilterProvider for chainA and chainDefault. + assertWithMessage("LDS 2: expected new filter instances for type %s", altTypeUrl) + .that(lds2SnapshotAlt).hasSize(2); + StatefulFilter lds2ChainAFilter2Alt = lds2SnapshotAlt.get(0); + StatefulFilter lds2ChainADefault2Alt = lds2SnapshotAlt.get(1); + // Confirm two new distict instances of STATEFUL_2 were created. + assertThat(lds2ChainAFilter2Alt).isNotSameInstanceAs(lds1ChainAFilter2); + assertThat(lds2ChainADefault2Alt).isNotSameInstanceAs(lds1ChainDefaultFilter2); + assertThat(lds2ChainAFilter2Alt).isNotSameInstanceAs(lds2ChainADefault2Alt); + // Verify the instance of STATEFUL_2 of the old type are shutdown. + assertThat(lds1ChainAFilter2.isShutdown()).isTrue(); + assertThat(lds1ChainDefaultFilter2.isShutdown()).isTrue(); + // Verify the new instances of STATEFUL_2 and the old instances of STATEFUL_1 are running. 
+ assertThat(lds2ChainAFilter2Alt.isShutdown()).isFalse(); + assertThat(lds2ChainADefault2Alt.isShutdown()).isFalse(); + assertThat(lds1ChainAFilter1.isShutdown()).isFalse(); + assertThat(lds1ChainDefaultFilter1.isShutdown()).isFalse(); + } + + /** + * Verifies that all filter instances are shutdown (closed) on LDS resource not found. + */ + @Test + public void filterState_shutdown_onLdsNotFound() { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = filterStateTestFilterRegistry(statefulFilterProvider); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + // Test a normal chain and the default chain, as it's handled separately. + VirtualHost vhost = filterStateTestVhost(); + FilterChain chainA = createFilterChain("chain_a", + createHcm(vhost, filterStateTestConfigs(STATEFUL_1))); + FilterChain chainDefault = createFilterChain("chain_default", + createHcm(vhost, filterStateTestConfigs(STATEFUL_2))); + + // LDS 1. + xdsClient.deliverLdsUpdate(chainA, chainDefault); + verifyServerStarted(serverStart); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsChainFilter + StatefulFilter lds1ChainAFilter1 = lds1Snapshot.get(0); + StatefulFilter lds1ChainDefaultFilter2 = lds1Snapshot.get(1); + + // LDS 2: resource not found. + xdsClient.deliverLdsResourceNotFound(); + // Verify shutdown. + assertThat(lds1ChainAFilter1.isShutdown()).isTrue(); + assertThat(lds1ChainDefaultFilter2.isShutdown()).isTrue(); + } + + /** + * Verifies that all filter instances of a filter chain are shutdown when said chain is removed. 
+ */ + @Test + public void filterState_shutdown_onChainRemoved() { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = filterStateTestFilterRegistry(statefulFilterProvider); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + ImmutableList configs = filterStateTestConfigs(STATEFUL_1, STATEFUL_2); + FilterChain chainA = createFilterChain("chain_a", + createHcm(filterStateTestVhost("stateful_vhost_a"), configs), + createMatchSrcIp("3fff:a::/32")); + FilterChain chainB = createFilterChain("chain_b", + createHcm(filterStateTestVhost("stateful_vhost_b"), configs), + createMatchSrcIp("3fff:b::/32")); + FilterChain chainDefault = createFilterChain("chain_default", + createHcm(filterStateTestVhost("stateful_vhost_default"), configs), + createMatchSrcIp("3fff:defa::/32")); + + // LDS 1. + xdsClient.deliverLdsUpdate(ImmutableList.of(chainA, chainB), chainDefault); + verifyServerStarted(serverStart); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(6); + StatefulFilter chainAFilter1 = lds1Snapshot.get(0); + StatefulFilter chainAFilter2 = lds1Snapshot.get(1); + StatefulFilter chainBFilter1 = lds1Snapshot.get(2); + StatefulFilter chainBFilter2 = lds1Snapshot.get(3); + StatefulFilter chainDefaultFilter1 = lds1Snapshot.get(4); + StatefulFilter chainDefaultFilter2 = lds1Snapshot.get(5); + + // LDS 2: ChainB and ChainDefault are gone. + xdsClient.deliverLdsUpdate(chainA, null); + assertThat(statefulFilterProvider.getAllInstances()).isEqualTo(lds1Snapshot); + // ChainA filters not shutdown (just in case). + assertThat(chainAFilter1.isShutdown()).isFalse(); + assertThat(chainAFilter2.isShutdown()).isFalse(); + // ChainB and ChainDefault filters shutdown. 
+ assertWithMessage("chainBFilter1").that(chainBFilter1.isShutdown()).isTrue(); + assertWithMessage("chainBFilter2").that(chainBFilter2.isShutdown()).isTrue(); + assertWithMessage("chainDefaultFilter1").that(chainDefaultFilter1.isShutdown()).isTrue(); + assertWithMessage("chainDefaultFilter2").that(chainDefaultFilter2.isShutdown()).isTrue(); + } + + /** + * Verifies that all filter instances are shutdown (closed) on LDS ResourceWatcher shutdown. + */ + @Test + public void filterState_shutdown_onServerShutdown() { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = filterStateTestFilterRegistry(statefulFilterProvider); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + // Test a normal chain and the default chain, as it's handled separately. + VirtualHost vhost = filterStateTestVhost(); + FilterChain chainA = createFilterChain("chain_a", + createHcm(vhost, filterStateTestConfigs(STATEFUL_1))); + FilterChain chainDefault = createFilterChain("chain_default", + createHcm(vhost, filterStateTestConfigs(STATEFUL_2))); + + // LDS 1. + xdsClient.deliverLdsUpdate(chainA, chainDefault); + verifyServerStarted(serverStart); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsChainFilter + StatefulFilter lds1ChainAFilter1 = lds1Snapshot.get(0); + StatefulFilter lds1ChainDefaultFilter2 = lds1Snapshot.get(1); + + // Shutdown. + xdsServerWrapper.shutdown(); + assertThat(xdsServerWrapper.isShutdown()).isTrue(); + assertThat(xdsClient.isShutDown()).isTrue(); + // Verify shutdown. + assertThat(lds1ChainAFilter1.isShutdown()).isTrue(); + assertThat(lds1ChainDefaultFilter2.isShutdown()).isTrue(); + } + + /** + * Verifies that filter instances are NOT shutdown on RDS_RESOURCE_NAME not found. 
+ */ + @Test + public void filterState_shutdown_noShutdownOnRdsNotFound() throws Exception { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = filterStateTestFilterRegistry(statefulFilterProvider); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + String rdsName = "rds.example.com"; + // Test a normal chain and the default chain, as it's handled separately. + FilterChain chainA = createFilterChain("chain_a", + createHcmForRds(rdsName, filterStateTestConfigs(STATEFUL_1))); + FilterChain chainDefault = createFilterChain("chain_default", + createHcmForRds(rdsName, filterStateTestConfigs(STATEFUL_2))); + + xdsClient.deliverLdsUpdate(chainA, chainDefault); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + verify(listener, never()).onServing(); + // Server didn't start, but filter instances should have already been created. + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsChainFilter + StatefulFilter lds1ChainAFilter1 = lds1Snapshot.get(0); + StatefulFilter lds1ChainDefaultFilter2 = lds1Snapshot.get(1); + + // RDS 1: Standard vhost with a route. + xdsClient.deliverRdsUpdate(rdsName, filterStateTestVhost()); + verifyServerStarted(serverStart); + assertThat(statefulFilterProvider.getAllInstances()).isEqualTo(lds1Snapshot); + + // RDS 2: RDS_RESOURCE_NAME not found. 
+ xdsClient.deliverRdsResourceNotFound(rdsName); + assertThat(lds1ChainAFilter1.isShutdown()).isFalse(); + assertThat(lds1ChainDefaultFilter2.isShutdown()).isFalse(); + } + + private FilterRegistry filterStateTestFilterRegistry( + StatefulFilter.Provider statefulFilterProvider) { + return FilterRegistry.newRegistry().register(statefulFilterProvider, ROUTER_FILTER_PROVIDER); + } + + private SettableFuture filterStateTestStartServer(FilterRegistry filterRegistry) { + xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, filterRegistry); + SettableFuture serverStart = SettableFuture.create(); + scheduleServerStart(xdsServerWrapper, serverStart); + return serverStart; + } + + private static ImmutableList filterStateTestConfigs(String... names) { + ImmutableList.Builder result = ImmutableList.builder(); + for (String name : names) { + result.add(new NamedFilterConfig(name, new StatefulFilter.Config(name))); + } + result.add(new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG)); + return result.build(); + } + + private static Route filterStateTestRoute(ImmutableMap perRouteOverrides) { + // Standard basic route for filterState tests. + return Route.forAction( + RouteMatch.withPathExactOnly("/grpc.test.HelloService/SayHello"), null, perRouteOverrides); + } + + private static VirtualHost filterStateTestVhost() { + return filterStateTestVhost("stateful-vhost", NO_FILTER_OVERRIDES); + } + + private static VirtualHost filterStateTestVhost(String name) { + return filterStateTestVhost(name, NO_FILTER_OVERRIDES); + } + + private static VirtualHost filterStateTestVhost( + String name, ImmutableMap perRouteOverrides) { + return VirtualHost.create( + name, + ImmutableList.of("stateful.test.example.com"), + ImmutableList.of(filterStateTestRoute(perRouteOverrides)), + NO_FILTER_OVERRIDES); + } + + // End filter state tests. 
+ + private void verifyServerStarted(SettableFuture serverStart) { + try { + serverStart.get(5, TimeUnit.SECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + throw new AssertionError("serverStart future failed to resolve within the timeout", e); + } + verify(listener).onServing(); + try { + verify(mockServer).start(); + } catch (IOException e) { + throw new AssertionError("mockServer.start() shouldn't throw", e); + } + } + + private Map> getSelectorRoutingConfigs() { + return selectorManager.getSelectorToUpdateSelector().getRoutingConfigs(); + } + + private ServerRoutingConfig getSelectorRoutingConfig(FilterChain fc) { + return getSelectorRoutingConfigs().get(fc).get(); + } + + private ImmutableList getSelectorVhosts(FilterChain fc) { + return getSelectorRoutingConfig(fc).virtualHosts(); + } + + public static void scheduleServerStart( + XdsServerWrapper xdsServerWrapper, SettableFuture serverStart) { + Executors.newSingleThreadExecutor().execute(() -> { + try { + serverStart.set(xdsServerWrapper.start()); + } catch (Exception e) { + serverStart.setException(e); + } + }); + } + private static FilterChain createFilterChain(String name, HttpConnectionManager hcm) { - return EnvoyServerProtoData.FilterChain.create(name, createMatch(), - hcm, createTls(), mock(TlsContextManager.class)); + return createFilterChain(name, hcm, createMatch()); + } + + private static FilterChain createFilterChain( + String name, HttpConnectionManager hcm, FilterChainMatch filterChainMatch) { + TlsContextManager tlsContextManager = mock(TlsContextManager.class); + return FilterChain.create(name, filterChainMatch, hcm, createTls(), tlsContextManager); } private static VirtualHost createVirtualHost(String name) { @@ -1273,17 +1952,27 @@ private static VirtualHost createVirtualHost(String name) { ImmutableMap.of()); } - private static HttpConnectionManager createRds(String name) { - return createRds(name, null); + private static HttpConnectionManager 
createHcm( + VirtualHost vhost, List filterConfigs) { + return HttpConnectionManager.forVirtualHosts(0L, ImmutableList.of(vhost), filterConfigs); + } + + private static HttpConnectionManager createHcmForRds( + String name, List filterConfigs) { + return HttpConnectionManager.forRdsName(0L, name, filterConfigs); } - private static HttpConnectionManager createRds(String name, FilterConfig filterConfig) { - return HttpConnectionManager.forRdsName(0L, name, - Arrays.asList(new NamedFilterConfig("named-config-" + name, filterConfig))); + private static HttpConnectionManager createRds(String name) { + NamedFilterConfig config = + new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG); + return createHcmForRds(name, ImmutableList.of(config)); } - private static EnvoyServerProtoData.FilterChainMatch createMatch() { - return EnvoyServerProtoData.FilterChainMatch.create( + /** + * Returns the least-specific match-all Filter Chain Match. + */ + static FilterChainMatch createMatch() { + return FilterChainMatch.create( 0, ImmutableList.of(), ImmutableList.of(), @@ -1294,6 +1983,21 @@ private static EnvoyServerProtoData.FilterChainMatch createMatch() { ""); } + private static FilterChainMatch createMatchSrcIp(String srcCidr) { + String[] srcParts = srcCidr.split("/", 2); + InetAddress ip = InetAddresses.forString(srcParts[0]); + Integer subnetMask = Integer.valueOf(srcParts[1], 10); + return FilterChainMatch.create( + 0, + ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of(CidrRange.create(ip, subnetMask)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + ImmutableList.of(), + ImmutableList.of(), + ""); + } + private static ServerRoutingConfig createRoutingConfig(String path, String domain) { return createRoutingConfig(path, domain, null); } @@ -1323,7 +2027,7 @@ private static MethodDescriptor createMethod(String path) { .build(); } - private static EnvoyServerProtoData.DownstreamTlsContext createTls() { + static 
EnvoyServerProtoData.DownstreamTlsContext createTls() { return CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); } } diff --git a/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java b/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java index cc12e3863ba..a54893c9075 100644 --- a/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java +++ b/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java @@ -106,7 +106,7 @@ public void setXdsConfig(final String type, final Map copyResources = new HashMap<>(resources); xdsResources.put(type, copyResources); - String newVersionInfo = String.valueOf(xdsVersions.get(type).getAndDecrement()); + String newVersionInfo = String.valueOf(xdsVersions.get(type).getAndIncrement()); for (Map.Entry<StreamObserver<DiscoveryResponse>, Set<String>> entry : subscribers.get(type).entrySet()) { @@ -119,6 +119,11 @@ public void run() { }); } + ImmutableMap<String, Message> getCurrentConfig(String type) { + HashMap<String, Message> hashMap = xdsResources.get(type); + return (hashMap != null) ? 
ImmutableMap.copyOf(hashMap) : ImmutableMap.of(); + } + @Override public StreamObserver<DiscoveryRequest> streamAggregatedResources( final StreamObserver<DiscoveryResponse> responseObserver) { @@ -159,7 +164,7 @@ public void run() { DiscoveryResponse response = generateResponse(resourceType, String.valueOf(xdsVersions.get(resourceType)), - String.valueOf(xdsNonces.get(resourceType).get(responseObserver)), + String.valueOf(xdsNonces.get(resourceType).get(responseObserver).addAndGet(1)), requestedResourceNames); responseObserver.onNext(response); subscribers.get(resourceType).put(responseObserver, requestedResourceNames); @@ -202,4 +207,12 @@ private DiscoveryResponse generateResponse(String resourceType, String version, } return responseBuilder.build(); } + + public Map<String, Integer> getSubscriberCounts() { + Map<String, Integer> subscriberCounts = new HashMap<>(); + for (String type : subscribers.keySet()) { + subscriberCounts.put(type, subscribers.get(type).size()); + } + return subscriberCounts; + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsTestUtils.java b/xds/src/test/java/io/grpc/xds/XdsTestUtils.java new file mode 100644 index 00000000000..f81957ee311 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/XdsTestUtils.java @@ -0,0 +1,437 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; + +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.protobuf.Any; +import com.google.protobuf.Message; +import com.google.protobuf.util.Durations; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterStats; +import io.envoyproxy.envoy.config.listener.v3.ApiListener; +import io.envoyproxy.envoy.config.listener.v3.Listener; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; +import io.envoyproxy.envoy.extensions.clusters.aggregate.v3.ClusterConfig; +import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; +import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; +import io.envoyproxy.envoy.service.load_stats.v3.LoadReportingServiceGrpc; +import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsRequest; +import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsResponse; +import io.grpc.BindableService; +import io.grpc.Context; +import io.grpc.Context.CancellationListener; +import io.grpc.InsecureChannelCredentials; +import io.grpc.StatusOr; +import io.grpc.internal.ExponentialBackoffPolicy; +import io.grpc.internal.FakeClock; +import io.grpc.internal.JsonParser; +import io.grpc.stub.StreamObserver; +import 
io.grpc.xds.Endpoints.LbEndpoint; +import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.XdsConfig.XdsClusterConfig.EndpointConfig; +import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; +import io.grpc.xds.client.Locality; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClientMetricReporter; +import io.grpc.xds.client.XdsResourceType; +import io.grpc.xds.client.XdsTransportFactory; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.mockito.ArgumentMatcher; +import org.mockito.InOrder; + +public class XdsTestUtils { + private static final Logger log = Logger.getLogger(XdsTestUtils.class.getName()); + static final String RDS_NAME = "route-config.googleapis.com"; + static final String CLUSTER_NAME = "cluster0"; + static final String EDS_NAME = "eds-service-0"; + static final String SERVER_LISTENER = "grpc/server?udpa.resource.listening_address="; + static final String HTTP_CONNECTION_MANAGER_TYPE_URL = + "type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3" + + ".HttpConnectionManager"; + static final Bootstrapper.ServerInfo EMPTY_BOOTSTRAPPER_SERVER_INFO = + Bootstrapper.ServerInfo.create( + "td.googleapis.com", InsecureChannelCredentials.create(), false, true, false, false); + public static final String ENDPOINT_HOSTNAME = "data-host"; + public static final int ENDPOINT_PORT = 1234; + + static BindableService createLrsService(AtomicBoolean lrsEnded, + Queue loadReportCalls) { + return new LoadReportingServiceGrpc.LoadReportingServiceImplBase() { + @Override + public StreamObserver streamLoadStats( + StreamObserver responseObserver) { + assertThat(lrsEnded.get()).isTrue(); + lrsEnded.set(false); + 
@SuppressWarnings("unchecked") + StreamObserver<LoadStatsRequest> requestObserver = mock(StreamObserver.class); + LrsRpcCall call = new LrsRpcCall(requestObserver, responseObserver); + Context.current().addListener( + new CancellationListener() { + @Override + public void cancelled(Context context) { + lrsEnded.set(true); + } + }, MoreExecutors.directExecutor()); + loadReportCalls.offer(call); + return requestObserver; + } + }; + } + + static boolean matchErrorDetail( + com.google.rpc.Status errorDetail, int expectedCode, List<String> expectedMessages) { + if (expectedCode != errorDetail.getCode()) { + return false; + } + List<String> errors = Splitter.on('\n').splitToList(errorDetail.getMessage()); + if (errors.size() != expectedMessages.size()) { + return false; + } + for (int i = 0; i < errors.size(); i++) { + if (!errors.get(i).startsWith(expectedMessages.get(i))) { + return false; + } + } + return true; + } + + static void setAdsConfig(XdsTestControlPlaneService service, String serverName) { + setAdsConfig(service, serverName, RDS_NAME, CLUSTER_NAME, EDS_NAME, ENDPOINT_HOSTNAME, + ENDPOINT_PORT); + } + + static void setAdsConfig(XdsTestControlPlaneService service, String serverName, String rdsName, + String clusterName, String edsName, String endpointHostname, + int endpointPort) { + + Listener serverListener = ControlPlaneRule.buildServerListener(); + Listener clientListener = ControlPlaneRule.buildClientListener(serverName, rdsName); + service.setXdsConfig(ADS_TYPE_URL_LDS, + ImmutableMap.of(SERVER_LISTENER, serverListener, serverName, clientListener)); + + RouteConfiguration routeConfig = + buildRouteConfiguration(serverName, rdsName, clusterName); + service.setXdsConfig(ADS_TYPE_URL_RDS, ImmutableMap.of(rdsName, routeConfig)); + + Cluster cluster = ControlPlaneRule.buildCluster(clusterName, edsName); + service.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(clusterName, cluster)); + + ClusterLoadAssignment clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.0.11", 
endpointHostname, endpointPort, edsName); + service.setXdsConfig(ADS_TYPE_URL_EDS, + ImmutableMap.of(edsName, clusterLoadAssignment)); + + log.log(Level.FINE, String.format("Set ADS config for %s with address %s:%d", + serverName, endpointHostname, endpointPort)); + + } + + static String getEdsNameForCluster(String clusterName) { + return "eds_" + clusterName; + } + + static void setAggregateCdsConfig(XdsTestControlPlaneService service, String serverName, + String clusterName, List children) { + Map clusterMap = new HashMap<>(); + + ClusterConfig rootConfig = ClusterConfig.newBuilder().addAllClusters(children).build(); + Cluster.CustomClusterType type = + Cluster.CustomClusterType.newBuilder() + .setName(XdsClusterResource.AGGREGATE_CLUSTER_TYPE_NAME) + .setTypedConfig(Any.pack(rootConfig)) + .build(); + Cluster.Builder builder = Cluster.newBuilder().setName(clusterName).setClusterType(type); + builder.setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN); + Cluster cluster = builder.build(); + clusterMap.put(clusterName, cluster); + + for (String child : children) { + Cluster childCluster = ControlPlaneRule.buildCluster(child, getEdsNameForCluster(child)); + clusterMap.put(child, childCluster); + } + + service.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + + Map edsMap = new HashMap<>(); + for (String child : children) { + ClusterLoadAssignment clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.0.16", ENDPOINT_HOSTNAME, ENDPOINT_PORT, getEdsNameForCluster(child)); + edsMap.put(getEdsNameForCluster(child), clusterLoadAssignment); + } + service.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + } + + static void addAggregateToExistingConfig(XdsTestControlPlaneService service, String rootName, + List children) { + Map clusterMap = new HashMap<>(service.getCurrentConfig(ADS_TYPE_URL_CDS)); + if (clusterMap.containsKey(rootName)) { + throw new IllegalArgumentException("Root cluster " + rootName + " already exists"); + } + ClusterConfig rootConfig = 
ClusterConfig.newBuilder().addAllClusters(children).build(); + Cluster.CustomClusterType type = + Cluster.CustomClusterType.newBuilder() + .setName(XdsClusterResource.AGGREGATE_CLUSTER_TYPE_NAME) + .setTypedConfig(Any.pack(rootConfig)) + .build(); + Cluster.Builder builder = Cluster.newBuilder().setName(rootName).setClusterType(type); + builder.setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN); + Cluster cluster = builder.build(); + clusterMap.put(rootName, cluster); + + for (String child : children) { + if (clusterMap.containsKey(child)) { + continue; + } + Cluster childCluster = ControlPlaneRule.buildCluster(child, getEdsNameForCluster(child)); + clusterMap.put(child, childCluster); + } + + service.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + + Map edsMap = new HashMap<>(service.getCurrentConfig(ADS_TYPE_URL_EDS)); + for (String child : children) { + if (edsMap.containsKey(getEdsNameForCluster(child))) { + continue; + } + ClusterLoadAssignment clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.0.15", ENDPOINT_HOSTNAME, ENDPOINT_PORT, getEdsNameForCluster(child)); + edsMap.put(getEdsNameForCluster(child), clusterLoadAssignment); + } + service.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + } + + static XdsConfig getDefaultXdsConfig(String serverHostName) + throws XdsResourceType.ResourceInvalidException, IOException { + XdsConfig.XdsConfigBuilder builder = new XdsConfig.XdsConfigBuilder(); + + Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig( + "terminal-filter", RouterFilter.ROUTER_CONFIG); + + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forRdsName( + 0L, RDS_NAME, Collections.singletonList(routerFilterConfig)); + XdsListenerResource.LdsUpdate ldsUpdate = + XdsListenerResource.LdsUpdate.forApiListener(httpConnectionManager); + + RouteConfiguration routeConfiguration = + buildRouteConfiguration(serverHostName, RDS_NAME, CLUSTER_NAME); + XdsResourceType.Args args = new XdsResourceType.Args( + 
EMPTY_BOOTSTRAPPER_SERVER_INFO, "0", "0", null, null, null); + XdsRouteConfigureResource.RdsUpdate rdsUpdate = + XdsRouteConfigureResource.getInstance().doParse(args, routeConfiguration); + + // Take advantage of knowing that there is only 1 virtual host in the route configuration + assertThat(rdsUpdate.virtualHosts).hasSize(1); + VirtualHost virtualHost = rdsUpdate.virtualHosts.get(0); + + // Need to create endpoints to create locality endpoints map to create edsUpdate + Map lbEndpointsMap = new HashMap<>(); + LbEndpoint lbEndpoint = LbEndpoint.create( + "127.0.0.11", ENDPOINT_PORT, 0, true, ENDPOINT_HOSTNAME, ImmutableMap.of()); + lbEndpointsMap.put( + Locality.create("", "", ""), + LocalityLbEndpoints.create(ImmutableList.of(lbEndpoint), 10, 0, ImmutableMap.of())); + + // Need to create EdsUpdate to create CdsUpdate to create XdsClusterConfig for builder + XdsEndpointResource.EdsUpdate edsUpdate = new XdsEndpointResource.EdsUpdate( + EDS_NAME, lbEndpointsMap, Collections.emptyList()); + XdsClusterResource.CdsUpdate cdsUpdate = XdsClusterResource.CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false, null) + .lbPolicyConfig(getWrrLbConfigAsMap()).build(); + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate))); + + builder + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(virtualHost) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)); + + return builder.build(); + } + + static Map createMinimalLbEndpointsMap(String serverAddress) { + Map lbEndpointsMap = new HashMap<>(); + LbEndpoint lbEndpoint = LbEndpoint.create( + serverAddress, ENDPOINT_PORT, 0, true, ENDPOINT_HOSTNAME, ImmutableMap.of()); + lbEndpointsMap.put( + Locality.create("", "", ""), + LocalityLbEndpoints.create(ImmutableList.of(lbEndpoint), 10, 0, ImmutableMap.of())); + return lbEndpointsMap; + } + + @SuppressWarnings("unchecked") + static ImmutableMap 
getWrrLbConfigAsMap() throws IOException { + String lbConfigStr = "{\"wrr_locality_experimental\" : " + + "{ \"childPolicy\" : [{\"round_robin\" : {}}]}}"; + + return ImmutableMap.copyOf((Map) JsonParser.parse(lbConfigStr)); + } + + static RouteConfiguration buildRouteConfiguration(String authority, String rdsName, + String clusterName) { + return ControlPlaneRule.buildRouteConfiguration(authority, rdsName, clusterName); + } + + static Cluster buildAggCluster(String name, List childNames) { + ClusterConfig rootConfig = ClusterConfig.newBuilder().addAllClusters(childNames).build(); + Cluster.CustomClusterType type = + Cluster.CustomClusterType.newBuilder() + .setName(XdsClusterResource.AGGREGATE_CLUSTER_TYPE_NAME) + .setTypedConfig(Any.pack(rootConfig)) + .build(); + Cluster.Builder builder = + Cluster.newBuilder().setName(name).setClusterType(type); + builder.setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN); + Cluster cluster = builder.build(); + return cluster; + } + + static void addEdsClusters(Map clusterMap, Map edsMap, + String... 
clusterNames) { + for (String clusterName : clusterNames) { + String edsName = getEdsNameForCluster(clusterName); + Cluster cluster = ControlPlaneRule.buildCluster(clusterName, edsName); + clusterMap.put(clusterName, cluster); + + ClusterLoadAssignment clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.0.13", ENDPOINT_HOSTNAME, ENDPOINT_PORT, edsName); + edsMap.put(edsName, clusterLoadAssignment); + } + } + + static Listener buildInlineClientListener(String rdsName, String clusterName, String serverName) { + HttpFilter + httpFilter = HttpFilter.newBuilder() + .setName("terminal-filter") + .setTypedConfig(Any.pack(Router.newBuilder().build())) + .setIsOptional(true) + .build(); + ApiListener.Builder clientListenerBuilder = + ApiListener.newBuilder().setApiListener(Any.pack( + io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3 + .HttpConnectionManager.newBuilder() + .setRouteConfig( + buildRouteConfiguration(serverName, rdsName, clusterName)) + .addAllHttpFilters(Collections.singletonList(httpFilter)) + .build(), + HTTP_CONNECTION_MANAGER_TYPE_URL)); + return Listener.newBuilder() + .setName(serverName) + .setApiListener(clientListenerBuilder.build()).build(); + } + + public static XdsClient createXdsClient( + List serverUris, + XdsTransportFactory xdsTransportFactory, + FakeClock fakeClock) { + return createXdsClient( + CommonBootstrapperTestUtils.buildBootStrap(serverUris), + xdsTransportFactory, + fakeClock, + new XdsClientMetricReporter() {}); + } + + /** Calls {@link CommonBootstrapperTestUtils#createXdsClient} with gRPC-specific values. 
*/ + public static XdsClient createXdsClient( + Bootstrapper.BootstrapInfo bootstrapInfo, + XdsTransportFactory xdsTransportFactory, + FakeClock fakeClock, + XdsClientMetricReporter xdsClientMetricReporter) { + return CommonBootstrapperTestUtils.createXdsClient( + bootstrapInfo, + xdsTransportFactory, + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); + } + + /** + * Matches a {@link LoadStatsRequest} containing a collection of {@link ClusterStats} with + * the same list of clusterName:clusterServiceName pair. + */ + static class LrsRequestMatcher implements ArgumentMatcher { + private final List expected; + + private LrsRequestMatcher(List clusterNames) { + expected = new ArrayList<>(); + for (String[] pair : clusterNames) { + expected.add(pair[0] + ":" + (pair[1] == null ? "" : pair[1])); + } + Collections.sort(expected); + } + + @Override + public boolean matches(LoadStatsRequest argument) { + List actual = new ArrayList<>(); + for (ClusterStats clusterStats : argument.getClusterStatsList()) { + actual.add(clusterStats.getClusterName() + ":" + clusterStats.getClusterServiceName()); + } + Collections.sort(actual); + return actual.equals(expected); + } + } + + static class LrsRpcCall { + private final StreamObserver requestObserver; + private final StreamObserver responseObserver; + private final InOrder inOrder; + + private LrsRpcCall(StreamObserver requestObserver, + StreamObserver responseObserver) { + this.requestObserver = requestObserver; + this.responseObserver = responseObserver; + inOrder = inOrder(requestObserver); + } + + protected void verifyNextReportClusters(List clusters) { + inOrder.verify(requestObserver).onNext(argThat(new LrsRequestMatcher(clusters))); + } + + protected void sendResponse(List clusters, long loadReportIntervalNano) { + LoadStatsResponse response = + LoadStatsResponse.newBuilder() + .addAllClusters(clusters) + 
.setLoadReportingInterval(Durations.fromNanos(loadReportIntervalNano)) + .build(); + responseObserver.onNext(response); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/client/BackendMetricPropagationTest.java b/xds/src/test/java/io/grpc/xds/client/BackendMetricPropagationTest.java new file mode 100644 index 00000000000..31ad6f9c47f --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/client/BackendMetricPropagationTest.java @@ -0,0 +1,151 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.client; + +import static com.google.common.truth.Truth.assertThat; +import static java.util.Arrays.asList; + +import com.google.common.collect.ImmutableList; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link BackendMetricPropagation}. 
+ */ +@RunWith(JUnit4.class) +public class BackendMetricPropagationTest { + + @Test + public void fromMetricSpecs_nullInput() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs(null); + + assertThat(config.propagateCpuUtilization).isFalse(); + assertThat(config.propagateMemUtilization).isFalse(); + assertThat(config.propagateApplicationUtilization).isFalse(); + assertThat(config.shouldPropagateNamedMetric("any")).isFalse(); + } + + @Test + public void fromMetricSpecs_emptyInput() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs(ImmutableList.of()); + + assertThat(config.propagateCpuUtilization).isFalse(); + assertThat(config.propagateMemUtilization).isFalse(); + assertThat(config.propagateApplicationUtilization).isFalse(); + assertThat(config.shouldPropagateNamedMetric("any")).isFalse(); + } + + @Test + public void fromMetricSpecs_partialStandardMetrics() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of("cpu_utilization", "mem_utilization")); + + assertThat(config.propagateCpuUtilization).isTrue(); + assertThat(config.propagateMemUtilization).isTrue(); + assertThat(config.propagateApplicationUtilization).isFalse(); + assertThat(config.shouldPropagateNamedMetric("any")).isFalse(); + } + + @Test + public void fromMetricSpecs_allStandardMetrics() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of("cpu_utilization", "mem_utilization", "application_utilization")); + + assertThat(config.propagateCpuUtilization).isTrue(); + assertThat(config.propagateMemUtilization).isTrue(); + assertThat(config.propagateApplicationUtilization).isTrue(); + assertThat(config.shouldPropagateNamedMetric("any")).isFalse(); + } + + @Test + public void fromMetricSpecs_wildcardNamedMetrics() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of("named_metrics.*")); + + 
assertThat(config.propagateCpuUtilization).isFalse(); + assertThat(config.propagateMemUtilization).isFalse(); + assertThat(config.propagateApplicationUtilization).isFalse(); + assertThat(config.shouldPropagateNamedMetric("any_key")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("another_key")).isTrue(); + } + + @Test + public void fromMetricSpecs_specificNamedMetrics() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of("named_metrics.foo", "named_metrics.bar")); + + assertThat(config.shouldPropagateNamedMetric("foo")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("bar")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("baz")).isFalse(); + assertThat(config.shouldPropagateNamedMetric("any")).isFalse(); + } + + @Test + public void fromMetricSpecs_mixedStandardAndNamed() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of("cpu_utilization", "named_metrics.foo", "named_metrics.bar")); + + assertThat(config.propagateCpuUtilization).isTrue(); + assertThat(config.propagateMemUtilization).isFalse(); + assertThat(config.shouldPropagateNamedMetric("foo")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("bar")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("baz")).isFalse(); + } + + @Test + public void fromMetricSpecs_wildcardAndSpecificNamedMetrics() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of("named_metrics.foo", "named_metrics.*")); + + assertThat(config.shouldPropagateNamedMetric("foo")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("bar")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("any_other_key")).isTrue(); + } + + @Test + public void fromMetricSpecs_malformedAndUnknownSpecs_areIgnored() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + asList( + "cpu_utilization", + null, // ignored + "disk_utilization", + 
"named_metrics.", // empty key + "named_metrics.valid" + )); + + assertThat(config.propagateCpuUtilization).isTrue(); + assertThat(config.propagateMemUtilization).isFalse(); + assertThat(config.shouldPropagateNamedMetric("disk_utilization")).isFalse(); + assertThat(config.shouldPropagateNamedMetric("valid")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("")).isFalse(); // from the empty key + } + + @Test + public void fromMetricSpecs_duplicateSpecs_areHandledGracefully() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of( + "cpu_utilization", + "named_metrics.foo", + "cpu_utilization", + "named_metrics.foo")); + + assertThat(config.propagateCpuUtilization).isTrue(); + assertThat(config.shouldPropagateNamedMetric("foo")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("bar")).isFalse(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/CommonBootstrapperTestUtils.java b/xds/src/test/java/io/grpc/xds/client/CommonBootstrapperTestUtils.java similarity index 57% rename from xds/src/test/java/io/grpc/xds/CommonBootstrapperTestUtils.java rename to xds/src/test/java/io/grpc/xds/client/CommonBootstrapperTestUtils.java index 0b2f3c7136b..e3760bd983f 100644 --- a/xds/src/test/java/io/grpc/xds/CommonBootstrapperTestUtils.java +++ b/xds/src/test/java/io/grpc/xds/client/CommonBootstrapperTestUtils.java @@ -14,21 +14,35 @@ * limitations under the License. 
*/ -package io.grpc.xds; +package io.grpc.xds.client; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.grpc.ChannelCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.internal.BackoffPolicy; +import io.grpc.internal.FakeClock; import io.grpc.internal.JsonParser; -import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.Bootstrapper.ServerInfo; -import io.grpc.xds.client.EnvoyProtoData; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.TlsContextManagerImpl; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import javax.annotation.Nullable; public class CommonBootstrapperTestUtils { + public static final String SERVER_URI = "trafficdirector.googleapis.com"; + private static final ChannelCredentials CHANNEL_CREDENTIALS = InsecureChannelCredentials.create(); + private static final String SERVER_URI_CUSTOM_AUTHORITY = "trafficdirector2.googleapis.com"; + private static final String SERVER_URI_EMPTY_AUTHORITY = "trafficdirector3.googleapis.com"; + public static final String LDS_RESOURCE = "listener.googleapis.com"; + public static final String RDS_RESOURCE = "route-configuration.googleapis.com"; + public static final String CDS_RESOURCE = "cluster.googleapis.com"; + public static final String EDS_RESOURCE = "cluster-load-assignment.googleapis.com"; + private static final String FILE_WATCHER_CONFIG = "{\"path\": \"/etc/secret/certs\"}"; private static final String MESHCA_CONFIG = "{\n" @@ -88,7 +102,7 @@ public static Bootstrapper.BootstrapInfo buildBootstrapInfo( String certInstanceName1, @Nullable String privateKey1, @Nullable String cert1, @Nullable String trustCa1, String certInstanceName2, String privateKey2, String cert2, - String trustCa2) { + String trustCa2, @Nullable String spiffeTrustMap) { // get temp file for each file try { if (privateKey1 != 
null) { @@ -109,6 +123,9 @@ public static Bootstrapper.BootstrapInfo buildBootstrapInfo( if (trustCa2 != null) { trustCa2 = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(trustCa2); } + if (spiffeTrustMap != null) { + spiffeTrustMap = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(spiffeTrustMap); + } } catch (IOException ioe) { throw new RuntimeException(ioe); } @@ -116,6 +133,9 @@ public static Bootstrapper.BootstrapInfo buildBootstrapInfo( config.put("certificate_file", cert1); config.put("private_key_file", privateKey1); config.put("ca_certificate_file", trustCa1); + if (spiffeTrustMap != null) { + config.put("spiffe_trust_bundle_map_file", spiffeTrustMap); + } Bootstrapper.CertificateProviderInfo certificateProviderInfo = Bootstrapper.CertificateProviderInfo.create("file_watcher", config); HashMap<String, Bootstrapper.CertificateProviderInfo> certProviders = @@ -126,6 +146,9 @@ public static Bootstrapper.BootstrapInfo buildBootstrapInfo( config.put("certificate_file", cert2); config.put("private_key_file", privateKey2); config.put("ca_certificate_file", trustCa2); + if (spiffeTrustMap != null) { + config.put("spiffe_trust_bundle_map_file", spiffeTrustMap); + } certificateProviderInfo = Bootstrapper.CertificateProviderInfo.create("file_watcher", config); certProviders.put(certInstanceName2, certificateProviderInfo); @@ -136,4 +159,71 @@ public static Bootstrapper.BootstrapInfo buildBootstrapInfo( .certProviders(certProviders) .build(); } + + public static boolean setEnableXdsFallback(boolean target) { + boolean oldValue = BootstrapperImpl.enableXdsFallback; + BootstrapperImpl.enableXdsFallback = target; + return oldValue; + } + + public static XdsClientImpl createXdsClient(List<String> serverUris, + XdsTransportFactory xdsTransportFactory, + FakeClock fakeClock, + BackoffPolicy.Provider backoffPolicyProvider, + MessagePrettyPrinter messagePrinter, + XdsClientMetricReporter xdsClientMetricReporter) { + return createXdsClient( + buildBootStrap(serverUris), + xdsTransportFactory, + fakeClock, + 
backoffPolicyProvider, + messagePrinter, + xdsClientMetricReporter); + } + + public static XdsClientImpl createXdsClient(Bootstrapper.BootstrapInfo bootstrapInfo, + XdsTransportFactory xdsTransportFactory, + FakeClock fakeClock, + BackoffPolicy.Provider backoffPolicyProvider, + MessagePrettyPrinter messagePrinter, + XdsClientMetricReporter xdsClientMetricReporter) { + return new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + fakeClock.getTimeProvider(), + messagePrinter, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + } + + public static Bootstrapper.BootstrapInfo buildBootStrap(List<String> serverUris) { + + List<ServerInfo> serverInfos = new ArrayList<>(); + for (String uri : serverUris) { + serverInfos.add(ServerInfo.create(uri, CHANNEL_CREDENTIALS, false, true, false, false)); + } + EnvoyProtoData.Node node = EnvoyProtoData.Node.newBuilder().setId("node-id").build(); + + return Bootstrapper.BootstrapInfo.builder() + .servers(serverInfos) + .node(node) + .authorities(ImmutableMap.of( + "authority.xds.com", + Bootstrapper.AuthorityInfo.create( + "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_CUSTOM_AUTHORITY, CHANNEL_CREDENTIALS))), + "", + Bootstrapper.AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of("cert-instance-name", + Bootstrapper.CertificateProviderInfo.create("file-watcher", ImmutableMap.of()))) + .build(); + } + } diff --git a/xds/src/test/java/io/grpc/xds/client/LoadStatsManager2Test.java b/xds/src/test/java/io/grpc/xds/client/LoadStatsManager2Test.java index 9a90a92dcbd..a0642f7e4bb 100644 --- a/xds/src/test/java/io/grpc/xds/client/LoadStatsManager2Test.java +++ 
b/xds/src/test/java/io/grpc/xds/client/LoadStatsManager2Test.java @@ -27,6 +27,7 @@ import io.grpc.xds.client.Stats.ClusterStats; import io.grpc.xds.client.Stats.DroppedRequests; import io.grpc.xds.client.Stats.UpstreamLocalityStats; +import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.concurrent.TimeUnit; @@ -254,6 +255,59 @@ public void sharedLoadCounterStatsAggregation() { 2.718); } + @Test + public void recordMetrics_orcaLrsPropagationEnabled_specificMetrics() { + boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; + LoadStatsManager2.isEnabledOrcaLrsPropagation = true; + BackendMetricPropagation backendMetricPropagation = BackendMetricPropagation.fromMetricSpecs( + Arrays.asList("cpu_utilization", "named_metrics.named1")); + ClusterLocalityStats stats = loadStatsManager.getClusterLocalityStats( + CLUSTER_NAME1, EDS_SERVICE_NAME1, LOCALITY1, backendMetricPropagation); + + stats.recordTopLevelMetrics(0.8, 0.5, 0.0); // cpu, mem, app + stats.recordBackendLoadMetricStats(ImmutableMap.of("named1", 123.4, "named2", 567.8)); + stats.recordCallFinished(Status.OK); + ClusterStats report = Iterables.getOnlyElement( + loadStatsManager.getClusterStatsReports(CLUSTER_NAME1)); + UpstreamLocalityStats localityStats = + Iterables.getOnlyElement(report.upstreamLocalityStatsList()); + + assertThat(localityStats.loadMetricStatsMap()).containsKey("cpu_utilization"); + assertThat(localityStats.loadMetricStatsMap().get("cpu_utilization").totalMetricValue()) + .isWithin(TOLERANCE).of(0.8); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("mem_utilization"); + assertThat(localityStats.loadMetricStatsMap()).containsKey("named_metrics.named1"); + assertThat(localityStats.loadMetricStatsMap().get("named_metrics.named1").totalMetricValue()) + .isWithin(TOLERANCE).of(123.4); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("named_metrics.named2"); + LoadStatsManager2.isEnabledOrcaLrsPropagation = 
originalVal; + } + + @Test + public void recordMetrics_orcaLrsPropagationEnabled_wildcardNamedMetrics() { + boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; + LoadStatsManager2.isEnabledOrcaLrsPropagation = true; + BackendMetricPropagation backendMetricPropagation = BackendMetricPropagation.fromMetricSpecs( + Arrays.asList("named_metrics.*")); + ClusterLocalityStats stats = loadStatsManager.getClusterLocalityStats( + CLUSTER_NAME1, EDS_SERVICE_NAME1, LOCALITY1, backendMetricPropagation); + + stats.recordBackendLoadMetricStats(ImmutableMap.of("named1", 123.4, "named2", 567.8)); + stats.recordCallFinished(Status.OK); + ClusterStats report = Iterables.getOnlyElement( + loadStatsManager.getClusterStatsReports(CLUSTER_NAME1)); + UpstreamLocalityStats localityStats = + Iterables.getOnlyElement(report.upstreamLocalityStatsList()); + + assertThat(localityStats.loadMetricStatsMap()).containsKey("named_metrics.named1"); + assertThat(localityStats.loadMetricStatsMap().get("named_metrics.named1").totalMetricValue()) + .isWithin(TOLERANCE).of(123.4); + assertThat(localityStats.loadMetricStatsMap()).containsKey("named_metrics.named2"); + assertThat(localityStats.loadMetricStatsMap().get("named_metrics.named2").totalMetricValue()) + .isWithin(TOLERANCE).of(567.8); + LoadStatsManager2.isEnabledOrcaLrsPropagation = originalVal; + } + @Test public void loadCounterDelayedDeletionAfterAllInProgressRequestsReported() { ClusterLocalityStats counter = loadStatsManager.getClusterLocalityStats( diff --git a/xds/src/test/java/io/grpc/xds/internal/MatcherParserTest.java b/xds/src/test/java/io/grpc/xds/internal/MatcherParserTest.java new file mode 100644 index 00000000000..86a6a95fd4b --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/MatcherParserTest.java @@ -0,0 +1,85 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.envoyproxy.envoy.type.v3.FractionalPercent; +import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MatcherParserTest { + + @Test + public void parseStringMatcher_exact() { + StringMatcher proto = + StringMatcher.newBuilder().setExact("exact-match").setIgnoreCase(true).build(); + Matchers.StringMatcher matcher = MatcherParser.parseStringMatcher(proto); + assertThat(matcher).isNotNull(); + } + + @Test + public void parseStringMatcher_allTypes() { + MatcherParser.parseStringMatcher(StringMatcher.newBuilder().setExact("test").build()); + MatcherParser.parseStringMatcher(StringMatcher.newBuilder().setPrefix("test").build()); + MatcherParser.parseStringMatcher(StringMatcher.newBuilder().setSuffix("test").build()); + MatcherParser.parseStringMatcher(StringMatcher.newBuilder().setContains("test").build()); + MatcherParser.parseStringMatcher(StringMatcher.newBuilder() + .setSafeRegex(RegexMatcher.newBuilder().setRegex(".*").build()).build()); + } + + @Test + public void parseStringMatcher_unknownTypeThrows() { + StringMatcher unknownProto = StringMatcher.getDefaultInstance(); + IllegalArgumentException exception = 
assertThrows(IllegalArgumentException.class, + () -> MatcherParser.parseStringMatcher(unknownProto)); + assertThat(exception).hasMessageThat().contains("Unknown StringMatcher match pattern"); + } + + @Test + public void parseFractionMatcher_denominators() { + Matchers.FractionMatcher hundred = MatcherParser.parseFractionMatcher(FractionalPercent + .newBuilder().setNumerator(1).setDenominator(DenominatorType.HUNDRED).build()); + assertThat(hundred.numerator()).isEqualTo(1); + assertThat(hundred.denominator()).isEqualTo(100); + + Matchers.FractionMatcher tenThousand = MatcherParser.parseFractionMatcher(FractionalPercent + .newBuilder().setNumerator(2).setDenominator(DenominatorType.TEN_THOUSAND).build()); + assertThat(tenThousand.numerator()).isEqualTo(2); + assertThat(tenThousand.denominator()).isEqualTo(10_000); + + Matchers.FractionMatcher million = MatcherParser.parseFractionMatcher(FractionalPercent + .newBuilder().setNumerator(3).setDenominator(DenominatorType.MILLION).build()); + assertThat(million.numerator()).isEqualTo(3); + assertThat(million.denominator()).isEqualTo(1_000_000); + } + + @Test + public void parseFractionMatcher_unknownDenominatorThrows() { + FractionalPercent unknownProto = + FractionalPercent.newBuilder().setDenominatorValue(999).build(); + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> MatcherParser.parseFractionMatcher(unknownProto)); + assertThat(exception).hasMessageThat().contains("Unknown denominator type"); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java b/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java new file mode 100644 index 00000000000..bf5e0ae9ede --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java @@ -0,0 +1,98 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import io.grpc.services.InternalCallMetricRecorder; +import io.grpc.services.MetricReport; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.OptionalDouble; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link MetricReportUtils}. */ +@RunWith(JUnit4.class) +public class MetricReportUtilsTest { + + @Test + public void getMetric_cpuUtilization() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + OptionalDouble result = MetricReportUtils.getMetric(report, "cpu_utilization"); + assertTrue(result.isPresent()); + assertEquals(0.5, result.getAsDouble(), 0.0001); + } + + @Test + public void getMetric_applicationUtilization() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + OptionalDouble result = MetricReportUtils.getMetric(report, "application_utilization"); + assertTrue(result.isPresent()); + assertEquals(0.1, result.getAsDouble(), 0.0001); + } + + @Test + public void getMetric_memUtilization() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + OptionalDouble result = MetricReportUtils.getMetric(report, "mem_utilization"); + assertTrue(result.isPresent()); + assertEquals(0.2, result.getAsDouble(), 0.0001); + } + + @Test + public 
void getMetric_utilizationMetric() { + Map<String, Double> utilizationMetrics = new HashMap<>(); + utilizationMetrics.put("foo", 1.23); + MetricReport report = InternalCallMetricRecorder.createMetricReport( + 0, 0, 0, 0, 0, Collections.emptyMap(), utilizationMetrics, Collections.emptyMap()); + + OptionalDouble result = MetricReportUtils.getMetric(report, "utilization.foo"); + assertTrue(result.isPresent()); + assertEquals(1.23, result.getAsDouble(), 0.0001); + assertFalse(MetricReportUtils.getMetric(report, "utilization.bar").isPresent()); + } + + @Test + public void getMetric_namedMetric() { + Map<String, Double> namedMetrics = new HashMap<>(); + namedMetrics.put("foo", 7.89); + MetricReport report = createMetricReport(0, 0, 0, 0, 0, namedMetrics); + OptionalDouble result = MetricReportUtils.getMetric(report, "named_metrics.foo"); + assertTrue(result.isPresent()); + assertEquals(7.89, result.getAsDouble(), 0.0001); + + assertFalse(MetricReportUtils.getMetric(report, "named_metrics.bar").isPresent()); + } + + @Test + public void getMetric_unknownPrefix() { + MetricReport report = createMetricReport(0, 0, 0, 0, 0, Collections.emptyMap()); + assertFalse(MetricReportUtils.getMetric(report, "unknown.foo").isPresent()); + assertFalse(MetricReportUtils.getMetric(report, "foo").isPresent()); + } + + private MetricReport createMetricReport(double cpu, double app, double mem, double qps, + double eps, Map<String, Double> namedMetrics) { + return InternalCallMetricRecorder.createMetricReport( + cpu, app, mem, qps, eps, Collections.emptyMap(), Collections.emptyMap(), namedMetrics); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/ProtobufJsonConverterTest.java b/xds/src/test/java/io/grpc/xds/internal/ProtobufJsonConverterTest.java new file mode 100644 index 00000000000..86f9be4dda8 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/ProtobufJsonConverterTest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not 
use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; + +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ListValue; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ProtobufJsonConverterTest { + + @Test + public void testEmptyStruct() { + Struct emptyStruct = Struct.newBuilder().build(); + Map<String, Object> result = ProtobufJsonConverter.convertToJson(emptyStruct); + assertThat(result).isEmpty(); + } + + @Test + public void testStructWithValues() { + Struct struct = Struct.newBuilder() + .putFields("stringKey", Value.newBuilder().setStringValue("stringValue").build()) + .putFields("numberKey", Value.newBuilder().setNumberValue(123.45).build()) + .putFields("boolKey", Value.newBuilder().setBoolValue(true).build()) + .putFields("nullKey", Value.newBuilder().setNullValueValue(0).build()) + .putFields("structKey", Value.newBuilder() + .setStructValue(Struct.newBuilder() + .putFields("nestedKey", Value.newBuilder().setStringValue("nestedValue").build()) + .build()) + .build()) + .putFields("listKey", Value.newBuilder() + .setListValue(ListValue.newBuilder() + .addValues(Value.newBuilder().setNumberValue(1).build()) + 
.addValues(Value.newBuilder().setStringValue("two").build()) + .addValues(Value.newBuilder().setBoolValue(false).build()) + .build()) + .build()) + .build(); + + Map<String, Object> result = ProtobufJsonConverter.convertToJson(struct); + + Map<String, Object> goldenResult = new HashMap<>(); + goldenResult.put("stringKey", "stringValue"); + goldenResult.put("numberKey", 123.45); + goldenResult.put("boolKey", true); + goldenResult.put("nullKey", null); + goldenResult.put("structKey", ImmutableMap.of("nestedKey", "nestedValue")); + goldenResult.put("listKey", Arrays.asList(1.0, "two", false)); + + assertEquals(goldenResult, result); + } + + @Test(expected = IllegalArgumentException.class) + public void testUnknownValueType() { + Value unknownValue = Value.newBuilder().build(); // Default instance with no kind case set. + ProtobufJsonConverter.convertToJson( + Struct.newBuilder().putFields("unknownKey", unknownValue).build()); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/grpcservice/HeaderValueTest.java b/xds/src/test/java/io/grpc/xds/internal/grpcservice/HeaderValueTest.java new file mode 100644 index 00000000000..b55e6ae76f7 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/grpcservice/HeaderValueTest.java @@ -0,0 +1,49 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds.internal.grpcservice; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.protobuf.ByteString; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderValueTest { + + @Test + public void create_withStringValue_success() { + HeaderValue headerValue = HeaderValue.create("key1", "value1"); + assertThat(headerValue.key()).isEqualTo("key1"); + assertThat(headerValue.value().isPresent()).isTrue(); + assertThat(headerValue.value().get()).isEqualTo("value1"); + assertThat(headerValue.rawValue().isPresent()).isFalse(); + } + + @Test + public void create_withByteStringValue_success() { + ByteString rawValue = ByteString.copyFromUtf8("raw_value"); + HeaderValue headerValue = HeaderValue.create("key2", rawValue); + assertThat(headerValue.key()).isEqualTo("key2"); + assertThat(headerValue.rawValue().isPresent()).isTrue(); + assertThat(headerValue.rawValue().get()).isEqualTo(rawValue); + assertThat(headerValue.value().isPresent()).isFalse(); + } + + +} diff --git a/xds/src/test/java/io/grpc/xds/internal/grpcservice/HeaderValueValidationUtilsTest.java b/xds/src/test/java/io/grpc/xds/internal/grpcservice/HeaderValueValidationUtilsTest.java new file mode 100644 index 00000000000..c4658f3f305 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/grpcservice/HeaderValueValidationUtilsTest.java @@ -0,0 +1,87 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.protobuf.ByteString; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link HeaderValueValidationUtils}. + */ +@RunWith(JUnit4.class) +public class HeaderValueValidationUtilsTest { + + @Test + public void isDisallowed_string_emptyKey() { + assertThat(HeaderValueValidationUtils.isDisallowed("")).isTrue(); + } + + @Test + public void isDisallowed_string_tooLongKey() { + String longKey = new String(new char[16385]).replace('\0', 'a'); + assertThat(HeaderValueValidationUtils.isDisallowed(longKey)).isTrue(); + } + + @Test + public void isDisallowed_string_notLowercase() { + assertThat(HeaderValueValidationUtils.isDisallowed("Content-Type")).isTrue(); + } + + @Test + public void isDisallowed_string_grpcPrefix() { + assertThat(HeaderValueValidationUtils.isDisallowed("grpc-timeout")).isTrue(); + } + + @Test + public void isDisallowed_string_systemHeader_colon() { + assertThat(HeaderValueValidationUtils.isDisallowed(":authority")).isTrue(); + } + + @Test + public void isDisallowed_string_systemHeader_host() { + assertThat(HeaderValueValidationUtils.isDisallowed("host")).isTrue(); + } + + @Test + public void isDisallowed_string_valid() { + assertThat(HeaderValueValidationUtils.isDisallowed("content-type")).isFalse(); + } + + @Test + public void isDisallowed_headerValue_tooLongValue() { + String longValue = new String(new char[16385]).replace('\0', 'v'); + HeaderValue header = HeaderValue.create("content-type", longValue); + assertThat(HeaderValueValidationUtils.isDisallowed(header)).isTrue(); + } + + @Test + public void isDisallowed_headerValue_tooLongRawValue() { + ByteString longRawValue = ByteString.copyFrom(new byte[16385]); + HeaderValue header = 
HeaderValue.create("content-type", longRawValue); + assertThat(HeaderValueValidationUtils.isDisallowed(header)).isTrue(); + } + + @Test + public void isDisallowed_headerValue_valid() { + HeaderValue header = HeaderValue.create("content-type", "application/grpc"); + assertThat(HeaderValueValidationUtils.isDisallowed(header)).isFalse(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java new file mode 100644 index 00000000000..9f5cb75460f --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.re2j.Pattern; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutationRulesConfigTest { + @Test + public void testBuilderDefaultValues() { + HeaderMutationRulesConfig config = HeaderMutationRulesConfig.builder().build(); + assertFalse(config.disallowAll()); + assertFalse(config.disallowIsError()); + assertThat(config.allowExpression()).isEmpty(); + assertThat(config.disallowExpression()).isEmpty(); + } + + @Test + public void testBuilder_setDisallowAll() { + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().disallowAll(true).build(); + assertTrue(config.disallowAll()); + } + + @Test + public void testBuilder_setDisallowIsError() { + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().disallowIsError(true).build(); + assertTrue(config.disallowIsError()); + } + + @Test + public void testBuilder_setAllowExpression() { + Pattern pattern = Pattern.compile("allow.*"); + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().allowExpression(pattern).build(); + assertThat(config.allowExpression()).hasValue(pattern); + } + + @Test + public void testBuilder_setDisallowExpression() { + Pattern pattern = Pattern.compile("disallow.*"); + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().disallowExpression(pattern).build(); + assertThat(config.disallowExpression()).hasValue(pattern); + } + + @Test + public void testBuilder_setAll() { + Pattern allowPattern = Pattern.compile("allow.*"); + Pattern disallowPattern = Pattern.compile("disallow.*"); + HeaderMutationRulesConfig config = HeaderMutationRulesConfig.builder() + .disallowAll(true) + .disallowIsError(true) + 
.allowExpression(allowPattern) + .disallowExpression(disallowPattern) + .build(); + assertTrue(config.disallowAll()); + assertTrue(config.disallowIsError()); + assertThat(config.allowExpression()).hasValue(allowPattern); + assertThat(config.disallowExpression()).hasValue(disallowPattern); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParserTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParserTest.java new file mode 100644 index 00000000000..e880c197450 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParserTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.protobuf.BoolValue; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; +import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesParseException; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutationRulesParserTest { + + @Test + public void parse_protoWithAllFields_success() throws Exception { + HeaderMutationRules proto = HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("allow-.*")) + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("disallow-.*")) + .setDisallowAll(BoolValue.newBuilder().setValue(true).build()) + .setDisallowIsError(BoolValue.newBuilder().setValue(true).build()) + .build(); + + HeaderMutationRulesConfig config = HeaderMutationRulesParser.parse(proto); + + assertThat(config.allowExpression().isPresent()).isTrue(); + assertThat(config.allowExpression().get().pattern()).isEqualTo("allow-.*"); + + assertThat(config.disallowExpression().isPresent()).isTrue(); + assertThat(config.disallowExpression().get().pattern()).isEqualTo("disallow-.*"); + + assertThat(config.disallowAll()).isTrue(); + assertThat(config.disallowIsError()).isTrue(); + } + + @Test + public void parse_protoWithNoExpressions_success() throws Exception { + HeaderMutationRules proto = HeaderMutationRules.newBuilder().build(); + + HeaderMutationRulesConfig config = HeaderMutationRulesParser.parse(proto); + + assertThat(config.allowExpression().isPresent()).isFalse(); + assertThat(config.disallowExpression().isPresent()).isFalse(); + assertThat(config.disallowAll()).isFalse(); + assertThat(config.disallowIsError()).isFalse(); + } + + @Test + public void 
parse_invalidRegexAllowExpression_throwsHeaderMutationRulesParseException() { + HeaderMutationRules proto = HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("allow-[")) + .build(); + + HeaderMutationRulesParseException exception = assertThrows( + HeaderMutationRulesParseException.class, () -> HeaderMutationRulesParser.parse(proto)); + + assertThat(exception).hasMessageThat().contains("Invalid regex pattern for allow_expression"); + } + + @Test + public void parse_invalidRegexDisallowExpression_throwsHeaderMutationRulesParseException() { + HeaderMutationRules proto = HeaderMutationRules.newBuilder() + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("disallow-[")) + .build(); + + HeaderMutationRulesParseException exception = assertThrows( + HeaderMutationRulesParseException.class, () -> HeaderMutationRulesParser.parse(proto)); + + assertThat(exception).hasMessageThat() + .contains("Invalid regex pattern for disallow_expression"); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java index 4de881c710e..a0eac581d5c 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java @@ -28,18 +28,17 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; -import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.client.XdsInitializationException; import 
io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProviderFactory; import io.grpc.xds.internal.security.certprovider.CertificateProvider; import io.grpc.xds.internal.security.certprovider.CertificateProviderProvider; import io.grpc.xds.internal.security.certprovider.CertificateProviderRegistry; import io.grpc.xds.internal.security.certprovider.CertificateProviderStore; +import io.grpc.xds.internal.security.certprovider.IgnoreUpdatesWatcher; import io.grpc.xds.internal.security.certprovider.TestCertificateProvider; -import java.io.IOException; -import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -86,7 +85,7 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); // verify that bootstrapInfo is cached... 
sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); @@ -121,7 +120,7 @@ public void bothPresent_expectCertProviderClientSslContextProvider() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -147,7 +146,7 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -181,7 +180,7 @@ public void createCertProviderClientSslContextProvider_withStaticContext() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -211,8 +210,8 @@ public void createCertProviderClientSslContextProvider_2providers() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -248,8 +247,8 @@ public void createNewCertProviderClientSslContextProvider_withSans() { clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - 
verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -282,23 +281,7 @@ public void createNewCertProviderClientSslContextProvider_onlyRootCert() { clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - } - - @Test - public void createNullCommonTlsContext_exception() throws IOException { - clientSslContextProviderFactory = - new ClientSslContextProviderFactory( - null, certProviderClientSslContextProviderFactory); - UpstreamTlsContext upstreamTlsContext = new UpstreamTlsContext(null); - try { - clientSslContextProviderFactory.create(upstreamTlsContext); - Assert.fail("no exception thrown"); - } catch (NullPointerException expected) { - assertThat(expected) - .hasMessageThat() - .isEqualTo("upstreamTlsContext should have CommonTlsContext"); - } + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } static void createAndRegisterProviderProvider( @@ -328,14 +311,20 @@ public CertificateProvider answer(InvocationOnMock invocation) throws Throwable } static void verifyWatcher( - SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor) { + SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor, + boolean usesDelegateWatcher) { assertThat(watcherCaptor).isNotNull(); assertThat(watcherCaptor.getDownstreamWatchers()).hasSize(1); - assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) - .isSameInstanceAs(sslContextProvider); + if (usesDelegateWatcher) { + assertThat(((IgnoreUpdatesWatcher) watcherCaptor.getDownstreamWatchers().iterator().next()) + .getDelegate()) + .isSameInstanceAs(sslContextProvider); + } else { + 
assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) + .isSameInstanceAs(sslContextProvider); + } } - @SuppressWarnings("deprecation") static CommonTlsContext.Builder addFilenames( CommonTlsContext.Builder builder, String certChain, String privateKey, String trustCa) { TlsCertificate tlsCert = @@ -347,13 +336,10 @@ static CommonTlsContext.Builder addFilenames( CertificateValidationContext.newBuilder() .setTrustedCa(DataSource.newBuilder().setFilename(trustCa)) .build(); - CommonTlsContext.CertificateProviderInstance certificateProviderInstance = - builder.getValidationContextCertificateProviderInstance(); CommonTlsContext.CombinedCertificateValidationContext.Builder combinedBuilder = CommonTlsContext.CombinedCertificateValidationContext.newBuilder(); combinedBuilder - .setDefaultValidationContext(certContext) - .setValidationContextCertificateProviderInstance(certificateProviderInstance); + .setDefaultValidationContext(certContext); return builder .addTlsCertificates(tlsCert) .setCombinedValidationContext(combinedBuilder.build()); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 8a04a3d02a7..abacd2038f8 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -23,7 +23,6 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import 
io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; @@ -37,10 +36,12 @@ import java.io.InputStream; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.AbstractMap; import java.util.Arrays; import java.util.List; import java.util.concurrent.Executor; import javax.annotation.Nullable; +import javax.net.ssl.X509TrustManager; /** Utility class for client and server ssl provider tests. */ public class CommonTlsContextTestsUtil { @@ -48,59 +49,43 @@ public class CommonTlsContextTestsUtil { public static final String SERVER_0_PEM_FILE = "server0.pem"; public static final String SERVER_0_KEY_FILE = "server0.key"; public static final String SERVER_1_PEM_FILE = "server1.pem"; + public static final String SERVER_1_SPIFFE_PEM_FILE = "server1_spiffe.pem"; public static final String SERVER_1_KEY_FILE = "server1.key"; public static final String CLIENT_PEM_FILE = "client.pem"; + public static final String CLIENT_SPIFFE_PEM_FILE = "client_spiffe.pem"; public static final String CLIENT_KEY_FILE = "client.key"; public static final String CA_PEM_FILE = "ca.pem"; + public static final String SPIFFE_TRUST_MAP_FILE = "spiffebundle.json"; + public static final String SPIFFE_TRUST_MAP_1_FILE = "spiffebundle1.json"; /** Bad/untrusted server certs. */ public static final String BAD_SERVER_PEM_FILE = "badserver.pem"; public static final String BAD_SERVER_KEY_FILE = "badserver.key"; public static final String BAD_CLIENT_PEM_FILE = "badclient.pem"; public static final String BAD_CLIENT_KEY_FILE = "badclient.key"; + public static final String BAD_WILDCARD_DNS_PEM_FILE = + "sni-test-certs/bad_wildcard_dns_certificate.pem"; /** takes additional values and creates CombinedCertificateValidationContext as needed. 
*/ - @SuppressWarnings("deprecation") - static CommonTlsContext buildCommonTlsContextWithAdditionalValues( + private static CommonTlsContext buildCommonTlsContextWithAdditionalValues( String certInstanceName, String certName, String validationContextCertInstanceName, String validationContextCertName, Iterable matchSubjectAltNames, Iterable alpnNames) { - - CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); - - CertificateProviderInstance certificateProviderInstance = CertificateProviderInstance - .newBuilder().setInstanceName(certInstanceName).setCertificateName(certName).build(); - if (certificateProviderInstance != null) { - builder.setTlsCertificateCertificateProviderInstance(certificateProviderInstance); - } - CertificateProviderInstance validationCertificateProviderInstance = - CertificateProviderInstance.newBuilder().setInstanceName(validationContextCertInstanceName) - .setCertificateName(validationContextCertName).build(); - CertificateValidationContext certValidationContext = - matchSubjectAltNames == null - ? 
null - : CertificateValidationContext.newBuilder() - .addAllMatchSubjectAltNames(matchSubjectAltNames) - .build(); - if (validationCertificateProviderInstance != null) { - CombinedCertificateValidationContext.Builder combinedBuilder = - CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - validationCertificateProviderInstance); - if (certValidationContext != null) { - combinedBuilder = combinedBuilder.setDefaultValidationContext(certValidationContext); - } - builder.setCombinedValidationContext(combinedBuilder); - } else if (validationCertificateProviderInstance != null) { - builder - .setValidationContextCertificateProviderInstance(validationCertificateProviderInstance); - } else if (certValidationContext != null) { - builder.setValidationContext(certValidationContext); - } - if (alpnNames != null) { - builder.addAllAlpnProtocols(alpnNames); - } - return builder.build(); + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names + CertificateValidationContext.Builder certificateValidationContextBuilder + = CertificateValidationContext.newBuilder() + .addAllMatchSubjectAltNames(matchSubjectAltNames); + return CommonTlsContext.newBuilder() + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.newBuilder() + .setInstanceName(certInstanceName) + .setCertificateName(certName)) + .setCombinedValidationContext(CombinedCertificateValidationContext.newBuilder() + .setDefaultValidationContext(certificateValidationContextBuilder + .setCaCertificateProviderInstance(CertificateProviderPluginInstance.newBuilder() + .setInstanceName(validationContextCertInstanceName) + .setCertificateName(validationContextCertName)))) + .addAllAlpnProtocols(alpnNames) + .build(); } /** Helper method to build DownstreamTlsContext for multiple test classes. */ @@ -148,7 +133,7 @@ public static DownstreamTlsContext buildTestDownstreamTlsContext( useSans ? 
Arrays.asList( StringMatcher.newBuilder() .setExact("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob") - .build()) : null, + .build()) : Arrays.asList(), Arrays.asList("managed-tls")); } return buildDownstreamTlsContext(commonTlsContext, /* requireClientCert= */ false); @@ -168,11 +153,24 @@ public static String getTempFileNameForResourcesFile(String resFile) throws IOEx * Helper method to build UpstreamTlsContext for above tests. Called from other classes as well. */ static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( - CommonTlsContext commonTlsContext) { - UpstreamTlsContext upstreamTlsContext = - UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build(); + CommonTlsContext commonTlsContext) { + return buildUpstreamTlsContext(commonTlsContext, "", false, false); + } + + /** + * Helper method to build UpstreamTlsContext with SNI info. + */ + static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( + CommonTlsContext commonTlsContext, String sni, boolean autoHostSni, + boolean autoSniSanValidation) { + UpstreamTlsContext.Builder upstreamTlsContext = + UpstreamTlsContext.newBuilder() + .setCommonTlsContext(commonTlsContext) + .setAutoHostSni(autoHostSni) + .setAutoSniSanValidation(autoSniSanValidation) + .setSni(sni); return EnvoyServerProtoData.UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext( - upstreamTlsContext); + upstreamTlsContext.build()); } /** Helper method to build UpstreamTlsContext for multiple test classes. */ @@ -187,6 +185,21 @@ public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( null); } + /** Helper method to build UpstreamTlsContext with SNI info. */ + public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( + String commonInstanceName, boolean hasIdentityCert, String sni, boolean autoHostSni) { + return buildUpstreamTlsContextForCertProviderInstance( + hasIdentityCert ? commonInstanceName : null, + hasIdentityCert ? 
"default" : null, + commonInstanceName, + "ROOT", + null, + null, + sni, + autoHostSni, + false); + } + /** Gets a cert from contents of a resource. */ public static X509Certificate getCertFromResourceName(String resourceName) throws IOException, CertificateException { @@ -195,7 +208,6 @@ public static X509Certificate getCertFromResourceName(String resourceName) } } - @SuppressWarnings("deprecation") private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( String certInstanceName, String certName, @@ -206,10 +218,37 @@ private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); if (certInstanceName != null) { builder = - builder.setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.newBuilder() - .setInstanceName(certInstanceName) - .setCertificateName(certName)); + builder.setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder() + .setInstanceName(certInstanceName) + .setCertificateName(certName)); + } + builder = + addCertificateValidationContext( + builder, rootInstanceName, rootCertName, staticCertValidationContext); + if (alpnProtocols != null) { + builder.addAllAlpnProtocols(alpnProtocols); + } + return builder.build(); + } + + /** Helper method to build CommonTlsContext using deprecated certificate provider field. 
*/ + @SuppressWarnings("deprecation") + public static CommonTlsContext buildCommonTlsContextWithDeprecatedCertProviderInstance( + String certInstanceName, + String certName, + String rootInstanceName, + String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); + if (certInstanceName != null) { + // Use deprecated field (field 11) instead of current field (field 14) + builder = + builder.setTlsCertificateCertificateProviderInstance( + CommonTlsContext.CertificateProviderInstance.newBuilder() + .setInstanceName(certInstanceName) + .setCertificateName(certName)); } builder = addCertificateValidationContext( @@ -244,29 +283,28 @@ private static CommonTlsContext buildNewCommonTlsContextForCertProviderInstance( return builder.build(); } - @SuppressWarnings("deprecation") private static CommonTlsContext.Builder addCertificateValidationContext( CommonTlsContext.Builder builder, String rootInstanceName, String rootCertName, CertificateValidationContext staticCertValidationContext) { + if (staticCertValidationContext == null && rootInstanceName == null) { + return builder; + } + CertificateValidationContext.Builder contextBuilder; + if (staticCertValidationContext == null) { + contextBuilder = CertificateValidationContext.newBuilder(); + } else { + contextBuilder = staticCertValidationContext.toBuilder(); + } if (rootInstanceName != null) { - CertificateProviderInstance providerInstance = - CertificateProviderInstance.newBuilder() - .setInstanceName(rootInstanceName) - .setCertificateName(rootCertName) - .build(); - if (staticCertValidationContext != null) { - CombinedCertificateValidationContext combined = - CombinedCertificateValidationContext.newBuilder() - .setDefaultValidationContext(staticCertValidationContext) - .setValidationContextCertificateProviderInstance(providerInstance) - .build(); - return builder.setCombinedValidationContext(combined); - } - 
builder = builder.setValidationContextCertificateProviderInstance(providerInstance); + contextBuilder.setCaCertificateProviderInstance(CertificateProviderPluginInstance.newBuilder() + .setInstanceName(rootInstanceName) + .setCertificateName(rootCertName)); + builder.setValidationContext(contextBuilder.build()); } - return builder; + return builder.setCombinedValidationContext(CombinedCertificateValidationContext.newBuilder() + .setDefaultValidationContext(contextBuilder)); } private static CommonTlsContext.Builder addNewCertificateValidationContext( @@ -274,19 +312,19 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( String rootInstanceName, String rootCertName, CertificateValidationContext staticCertValidationContext) { + CertificateValidationContext.Builder validationContextBuilder = + staticCertValidationContext != null ? staticCertValidationContext.toBuilder() + : CertificateValidationContext.newBuilder(); if (rootInstanceName != null) { CertificateProviderPluginInstance providerInstance = CertificateProviderPluginInstance.newBuilder() .setInstanceName(rootInstanceName) .setCertificateName(rootCertName) .build(); - CertificateValidationContext.Builder validationContextBuilder = - staticCertValidationContext != null ? staticCertValidationContext.toBuilder() - : CertificateValidationContext.newBuilder(); - return builder.setValidationContext( - validationContextBuilder.setCaCertificateProviderInstance(providerInstance)); + validationContextBuilder = validationContextBuilder.setCaCertificateProviderInstance( + providerInstance); } - return builder; + return builder.setValidationContext(validationContextBuilder); } /** Helper method to build UpstreamTlsContext for CertProvider tests. 
*/ @@ -305,7 +343,31 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext)); + staticCertValidationContext), + "", false, false); + } + + /** Helper method to build UpstreamTlsContext with SNI info for CertProvider tests. */ + public static EnvoyServerProtoData.UpstreamTlsContext + buildUpstreamTlsContextForCertProviderInstance( + @Nullable String certInstanceName, + @Nullable String certName, + @Nullable String rootInstanceName, + @Nullable String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext, + String sni, + boolean autoHostSni, + boolean autoSniSanValidation) { + return buildUpstreamTlsContext( + buildCommonTlsContextForCertProviderInstance( + certInstanceName, + certName, + rootInstanceName, + rootCertName, + alpnProtocols, + staticCertValidationContext), + sni, autoHostSni, autoSniSanValidation); } /** Helper method to build UpstreamTlsContext for CertProvider tests. */ @@ -324,7 +386,8 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext)); + staticCertValidationContext), + "", false, false); } /** Helper method to build DownstreamTlsContext for CertProvider tests. */ @@ -368,14 +431,15 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( } /** Perform some simple checks on sslContext. 
*/ - public static void doChecksOnSslContext(boolean server, SslContext sslContext, + public static void doChecksOnSslContext(boolean server, + AbstractMap.SimpleImmutableEntry sslContextAndTm, List expectedApnProtos) { if (server) { - assertThat(sslContext.isServer()).isTrue(); + assertThat(sslContextAndTm.getKey().isServer()).isTrue(); } else { - assertThat(sslContext.isClient()).isTrue(); + assertThat(sslContextAndTm.getKey().isClient()).isTrue(); } - List apnProtos = sslContext.applicationProtocolNegotiator().protocols(); + List apnProtos = sslContextAndTm.getKey().applicationProtocolNegotiator().protocols(); assertThat(apnProtos).isNotNull(); if (expectedApnProtos != null) { assertThat(apnProtos).isEqualTo(expectedApnProtos); @@ -401,7 +465,7 @@ public static TestCallback getValueThruCallback(SslContextProvider provider, Exe public static class TestCallback extends SslContextProvider.Callback { - public SslContext updatedSslContext; + public AbstractMap.SimpleImmutableEntry updatedSslContext; public Throwable updatedThrowable; public TestCallback(Executor executor) { @@ -409,7 +473,8 @@ public TestCallback(Executor executor) { } @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContext) { updatedSslContext = sslContext; } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index da7f8113dfa..f11c661e211 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -45,12 +45,12 @@ import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.ProtocolNegotiationEvent; -import io.grpc.xds.CommonBootstrapperTestUtils; 
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.InternalXdsAttributes; import io.grpc.xds.TlsContextManager; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; +import io.grpc.xds.internal.XdsInternalAttributes; import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSecurityHandler; import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSecurityProtocolNegotiator; import io.grpc.xds.internal.security.certprovider.CommonCertProviderTestUtils; @@ -74,11 +74,13 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.cert.CertStoreException; +import java.util.AbstractMap; import java.util.Iterator; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import javax.net.ssl.X509TrustManager; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -87,6 +89,10 @@ @RunWith(JUnit4.class) public class SecurityProtocolNegotiatorsTest { + private static final String HOSTNAME = "hostname"; + private static final String SNI_IN_UTC = "sni-in-upstream-tls-context"; + private static final String FAKE_AUTHORITY = "authority"; + private final GrpcHttp2ConnectionHandler grpcHandler = FakeGrpcHttp2ConnectionHandler.newHandler(); @@ -122,8 +128,31 @@ public void clientSecurityProtocolNegotiatorNewHandler_noFallback_expectExceptio @Test public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() { + UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext( + CommonTlsContext.newBuilder().build()); + ClientSecurityProtocolNegotiator pn = + new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); + GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); + 
ChannelLogger logger = mock(ChannelLogger.class); + doNothing().when(logger).log(any(ChannelLogLevel.class), anyString()); + when(mockHandler.getNegotiationLogger()).thenReturn(logger); + TlsContextManager mockTlsContextManager = mock(TlsContextManager.class); + when(mockHandler.getEagAttributes()) + .thenReturn( + Attributes.newBuilder() + .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager)) + .build()); + ChannelHandler newHandler = pn.newHandler(mockHandler); + assertThat(newHandler).isNotNull(); + assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); + } + + @Test + public void clientSecurityProtocolNegotiator_autoHostSni_hostnamePassedToClientSecurityHandlr() { UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build()); + CommonTlsContextTestsUtil.buildUpstreamTlsContext( + CommonTlsContext.newBuilder().build(), "", true, false); ClientSecurityProtocolNegotiator pn = new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); @@ -134,12 +163,14 @@ public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() when(mockHandler.getEagAttributes()) .thenReturn( Attributes.newBuilder() - .set(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager)) + .set(XdsInternalAttributes.ATTR_ADDRESS_NAME, FAKE_AUTHORITY) .build()); ChannelHandler newHandler = pn.newHandler(mockHandler); assertThat(newHandler).isNotNull(); assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); + assertThat(((ClientSecurityHandler) newHandler).getSni()).isEqualTo(FAKE_AUTHORITY); } @Test @@ -149,7 +180,7 @@ public void 
clientSecurityHandler_addLast() CommonCertProviderTestUtils.register(executor); Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null); + CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); @@ -158,7 +189,7 @@ public void clientSecurityHandler_addLast() new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier); + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); pipeline.addLast(clientSecurityHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler); assertNotNull(channelHandlerCtx); @@ -169,19 +200,20 @@ public void clientSecurityHandler_addLast() sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { - future.set(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + future.set(sslContextAndTm); } @Override protected void onException(Throwable throwable) { future.set(throwable); } - }); + }, false); assertThat(executor.runDueTasks()).isEqualTo(1); channel.runPendingTasks(); Object fromFuture = future.get(2, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); + assertThat(fromFuture).isInstanceOf(AbstractMap.SimpleImmutableEntry.class); channel.runPendingTasks(); channelHandlerCtx = pipeline.context(clientSecurityHandler); assertThat(channelHandlerCtx).isNull(); @@ -195,6 +227,75 @@ protected void onException(Throwable throwable) { CommonCertProviderTestUtils.register0(); } 
+ @Test + public void sniInClientSecurityHandler_autoHostSniIsTrue_usesEndpointHostname() { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, + CLIENT_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, "", true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(HOSTNAME); + } + + @Test + public void sniInClientSecurityHandler_autoHostSni_endpointHostnameIsEmpty_usesSniFromUtc() { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, + CLIENT_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe-client", true, SNI_IN_UTC, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, ""); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } + + @Test + public void sniInClientSecurityHandler_autoHostSni_endpointHostnameIsNull_usesSniFromUtc() { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, + CLIENT_PEM_FILE, CA_PEM_FILE, null, null, null, 
null, null); + UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe-client", true, SNI_IN_UTC, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, null); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } + + @Test + public void sniInClientSecurityHandler_autoHostSniIsFalse_usesSniFromUpstreamTlsContext() { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, + CLIENT_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe-client", true, SNI_IN_UTC, false); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } + @Test public void serverSecurityHandler_addLast() throws InterruptedException, TimeoutException, ExecutionException { @@ -216,7 +317,7 @@ public SocketAddress remoteAddress() { pipeline = channel.pipeline(); Bootstrapper.BootstrapInfo bootstrapInfoForServer = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE, - SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null); + SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContext( 
"google_cloud_private_spiffe-server", true, true); @@ -246,19 +347,20 @@ public SocketAddress remoteAddress() { sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { - future.set(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + future.set(sslContextAndTm); } @Override protected void onException(Throwable throwable) { future.set(throwable); } - }); + }, false); channel.runPendingTasks(); // need this for tasks to execute on eventLoop assertThat(executor.runDueTasks()).isEqualTo(1); Object fromFuture = future.get(2, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); + assertThat(fromFuture).isInstanceOf(AbstractMap.SimpleImmutableEntry.class); channel.runPendingTasks(); channelHandlerCtx = pipeline.context(SecurityProtocolNegotiators.ServerSecurityHandler.class); assertThat(channelHandlerCtx).isNull(); @@ -356,12 +458,12 @@ public void nullTlsContext_nullFallbackProtocolNegotiator_expectException() { @Test public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() - throws InterruptedException, TimeoutException, ExecutionException { + throws InterruptedException, TimeoutException, ExecutionException { FakeClock executor = new FakeClock(); CommonCertProviderTestUtils.register(executor); Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils - .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null); + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, + CLIENT_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); @@ -370,7 +472,7 @@ public void 
clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEv new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier); + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); pipeline.addLast(clientSecurityHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler); @@ -382,19 +484,20 @@ public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEv sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { - future.set(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + future.set(sslContextAndTm); } @Override protected void onException(Throwable throwable) { future.set(throwable); } - }); + }, false); executor.runDueTasks(); channel.runPendingTasks(); // need this for tasks to execute on eventLoop Object fromFuture = future.get(5, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); + assertThat(fromFuture).isInstanceOf(AbstractMap.SimpleImmutableEntry.class); channel.runPendingTasks(); channelHandlerCtx = pipeline.context(clientSecurityHandler); assertThat(channelHandlerCtx).isNull(); @@ -412,7 +515,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_handleHandlerRemoved() { CommonCertProviderTestUtils.register(executor); Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null); + CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); @@ -421,7 +524,7 
@@ public void clientSecurityProtocolNegotiatorNewHandler_handleHandlerRemoved() { new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier); + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); pipeline.addLast(clientSecurityHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler); @@ -459,7 +562,7 @@ static FakeGrpcHttp2ConnectionHandler newHandler() { @Override public String getAuthority() { - return "authority"; + return FAKE_AUTHORITY; } } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java index c455385dae9..7a5a6c00639 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java @@ -24,10 +24,10 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; -import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.internal.security.certprovider.CertProviderServerSslContextProviderFactory; import io.grpc.xds.internal.security.certprovider.CertificateProvider; @@ -78,7 +78,7 @@ public void createCertProviderServerSslContextProvider() throws XdsInitializatio serverSslContextProviderFactory.create(downstreamTlsContext); 
assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); // verify that bootstrapInfo is cached... sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); @@ -117,7 +117,7 @@ public void bothPresent_expectCertProviderServerSslContextProvider() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -144,7 +144,7 @@ public void createCertProviderServerSslContextProvider_onlyCertInstance() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -179,7 +179,7 @@ public void createCertProviderServerSslContextProvider_withStaticContext() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); } @Test @@ -210,8 +210,8 @@ public void createCertProviderServerSslContextProvider_2providers() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, 
watcherCaptor[1], true); } @Test @@ -249,7 +249,7 @@ public void createNewCertProviderServerSslContextProvider_withSans() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index f476818297d..70a53c53205 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -17,8 +17,9 @@ package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.buildUpstreamTlsContext; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -26,10 +27,13 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; +import java.util.AbstractMap; import java.util.concurrent.Executor; +import javax.net.ssl.X509TrustManager; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -47,14 +51,14 @@ public class SslContextProviderSupplierTest { @Rule public final 
MockitoRule mocks = MockitoJUnit.rule(); @Mock private TlsContextManager mockTlsContextManager; + @Mock private Executor mockExecutor; private SslContextProviderSupplier supplier; private SslContextProvider mockSslContextProvider; - private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; + private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = + buildUpstreamTlsContext("google_cloud_private_spiffe", true); private SslContextProvider.Callback mockCallback; private void prepareSupplier() { - upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); mockSslContextProvider = mock(SslContextProvider.class); doReturn(mockSslContextProvider) .when(mockTlsContextManager) @@ -64,9 +68,8 @@ private void prepareSupplier() { private void callUpdateSslContext() { mockCallback = mock(SslContextProvider.Callback.class); - Executor mockExecutor = mock(Executor.class); doReturn(mockExecutor).when(mockCallback).getExecutor(); - supplier.updateSslContext(mockCallback); + supplier.updateSslContext(mockCallback, false); } @Test @@ -82,26 +85,57 @@ public void get_updateSecret() { verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); - SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); + @SuppressWarnings("unchecked") + AbstractMap.SimpleImmutableEntry mockSslContextAndTm = + mock(AbstractMap.SimpleImmutableEntry.class); + capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); + verify(mockCallback, times(1)) + .updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider)); SslContextProvider.Callback 
mockCallback = mock(SslContextProvider.Callback.class); - supplier.updateSslContext(mockCallback); + supplier.updateSslContext(mockCallback, false); verify(mockTlsContextManager, times(3)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); } @Test - public void get_onException() { + public void autoHostSniFalse_usesSniFromUpstreamTlsContext() { prepareSupplier(); callUpdateSslContext(); + verify(mockTlsContextManager, times(2)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + verify(mockTlsContextManager, times(0)) + .releaseClientSslContextProvider(any(SslContextProvider.class)); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); + @SuppressWarnings("unchecked") + AbstractMap.SimpleImmutableEntry mockSslContextAndTm = + mock(AbstractMap.SimpleImmutableEntry.class); + capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); + verify(mockCallback, times(1)) + .updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(eq(mockSslContextProvider)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + supplier.updateSslContext(mockCallback, false); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + } + + @Test + public void get_onException() { + prepareSupplier(); + callUpdateSslContext(); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); + verify(mockSslContextProvider, times(1)) + .addCallback(callbackCaptor.capture()); + SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); + assertThat(capturedCallback).isNotNull(); Exception exception 
= new Exception("test"); capturedCallback.onException(exception); verify(mockCallback, times(1)).onException(eq(exception)); @@ -109,6 +143,46 @@ public void get_onException() { .releaseClientSslContextProvider(eq(mockSslContextProvider)); } + @Test + public void systemRootCertsWithMtls_callbackExecutedFromProvider() { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + null, + "root-default", + null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build()); + prepareSupplier(); + + callUpdateSslContext(); + + verify(mockTlsContextManager, times(2)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + verify(mockTlsContextManager, times(0)) + .releaseClientSslContextProvider(any(SslContextProvider.class)); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); + verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); + SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); + assertThat(capturedCallback).isNotNull(); + @SuppressWarnings("unchecked") + AbstractMap.SimpleImmutableEntry mockSslContextAndTm = + mock(AbstractMap.SimpleImmutableEntry.class); + capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); + verify(mockCallback, times(1)) + .updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(eq(mockSslContextProvider)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + supplier.updateSslContext(mockCallback, false); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + } + @Test public void testClose() { prepareSupplier(); @@ -116,7 +190,7 @@ public void testClose() { supplier.close(); 
verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider)); - supplier.updateSslContext(mockCallback); + supplier.updateSslContext(mockCallback, false); verify(mockTlsContextManager, times(3)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); verify(mockTlsContextManager, times(1)) diff --git a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java index 4d04eeb41e0..035096a3528 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java @@ -30,10 +30,10 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; import org.junit.Rule; import org.junit.Test; @@ -57,7 +57,7 @@ public class TlsContextManagerTest { public void createServerSslContextProvider() { Bootstrapper.BootstrapInfo bootstrapInfoForServer = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE, - SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null); + SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContext( "google_cloud_private_spiffe-server", false, false); @@ -76,7 +76,7 @@ public void createServerSslContextProvider() { public void createClientSslContextProvider() { Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, 
CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null); + CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false); @@ -96,7 +96,7 @@ public void createServerSslContextProvider_differentInstance() { Bootstrapper.BootstrapInfo bootstrapInfoForServer = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE, "cert-instance2", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE, - CA_PEM_FILE); + CA_PEM_FILE, null); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContext( "google_cloud_private_spiffe-server", false, false); @@ -120,7 +120,7 @@ public void createServerSslContextProvider_differentInstance() { public void createClientSslContextProvider_differentInstance() { Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, "cert-instance-2", CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); + CA_PEM_FILE, "cert-instance-2", CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index 5925c5f03b1..91f02863ca4 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -33,9 +33,9 @@ import com.google.common.util.concurrent.MoreExecutors; import 
io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; -import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil.TestCallback; import java.util.Queue; @@ -72,15 +72,28 @@ private CertProviderClientSslContextProvider getSslContextProvider( String rootInstanceName, Bootstrapper.BootstrapInfo bootstrapInfo, Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext) { - EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( - certInstanceName, - "cert-default", - rootInstanceName, - "root-default", - alpnProtocols, - staticCertValidationContext); + CertificateValidationContext staticCertValidationContext, + boolean useSystemRootCerts) { + EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; + if (useSystemRootCerts) { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + } else { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + } return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, @@ -122,12 +135,12 @@ public void testProviderForClient_mtls() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* 
staticCertValidationContext= */ null, false); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -135,11 +148,11 @@ public void testProviderForClient_mtls() throws Exception { ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -168,11 +181,92 @@ public void testProviderForClient_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); + } + + @Test + public void testProviderForClient_systemRootCerts_mtls() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( 
+ "gcp_id", + null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(), + true); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); + + // now generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + TestCallback testCallback1 = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // now update id cert: sslContext should be updated i.e. 
different from the previous one + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } + @Test + public void testProviderForClient_systemRootCerts_regularTls() { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + null, + null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(), + true); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback.updatedSslContext).isEqualTo(provider.getSslContextAndTrustManager()); + + assertThat(watcherCaptor[0]).isNull(); + } + @Test public void testProviderForClient_mtls_newXds() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = @@ -190,7 +284,7 @@ public void testProviderForClient_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); 
assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -198,11 +292,11 @@ public void testProviderForClient_mtls_newXds() throws Exception { ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -231,7 +325,7 @@ public void testProviderForClient_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -248,7 +342,7 @@ public void testProviderForClient_queueExecutor() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); QueuedExecutor queuedExecutor = new QueuedExecutor(); TestCallback testCallback = @@ -281,16 +375,16 @@ public void testProviderForClient_tls() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, 
- /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -318,7 +412,7 @@ public void testProviderForClient_sslContextException_onError() throws Exception "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */null, - staticCertValidationContext); + staticCertValidationContext, false); TestCallback testCallback = new TestCallback(MoreExecutors.directExecutor()); provider.addCallback(testCallback); @@ -338,7 +432,8 @@ public void testProviderForClient_sslContextException_onError() throws Exception } @Test - public void testProviderForClient_rootInstanceNull_expectError() throws Exception { + public void testProviderForClient_rootInstanceNull_and_notUsingSystemRootCerts_expectError() + throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = new CertificateProvider.DistributorWatcher[1]; TestCertificateProvider.createAndRegisterProviderProvider( @@ -349,13 +444,84 @@ public void testProviderForClient_rootInstanceNull_expectError() throws Exceptio /* rootInstanceName= */ null, CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); fail("exception expected"); - } catch (NullPointerException expected) { - 
assertThat(expected).hasMessageThat().contains("Client SSL requires rootCertInstance"); + } catch (UnsupportedOperationException expected) { + assertThat(expected).hasMessageThat().contains("Unsupported configurations in " + + "UpstreamTlsContext!"); } } + @Test + public void testProviderForClient_rootInstanceNull_but_isUsingSystemRootCerts_valid() + throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + getSslContextProvider( + /* certInstanceName= */ null, + /* rootInstanceName= */ null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .build(), false); + } + + @Test + public void testProviderForClient_deprecatedCertProviderField() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + + // Build UpstreamTlsContext using deprecated field + EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = + new EnvoyServerProtoData.UpstreamTlsContext( + CommonTlsContextTestsUtil.buildCommonTlsContextWithDeprecatedCertProviderInstance( + "gcp_id", + "cert-default", + "gcp_id", + "root-default", + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null)); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); + CertProviderClientSslContextProvider provider = + (CertProviderClientSslContextProvider) + certProviderClientSslContextProviderFactory.getProvider( + upstreamTlsContext, + bootstrapInfo.node().toEnvoyProtoNode(), + 
bootstrapInfo.certProviders()); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); + + // Generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNotNull(); + assertThat(provider.savedCertChain).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); + + // Generate root cert update + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + } + static class QueuedExecutor implements Executor { /** A list of Runnables to be run in order. 
*/ @VisibleForTesting final Queue runQueue = new ConcurrentLinkedQueue<>(); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java index 82af7d1dc27..93559f47245 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java @@ -32,9 +32,9 @@ import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; -import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil.TestCallback; import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProviderTest.QueuedExecutor; @@ -127,7 +127,7 @@ public void testProviderForServer_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -135,11 +135,11 @@ public void testProviderForServer_mtls() throws Exception { ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate root cert 
update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -168,7 +168,7 @@ public void testProviderForServer_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -196,7 +196,7 @@ public void testProviderForServer_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -204,11 +204,11 @@ public void testProviderForServer_mtls_newXds() throws Exception { ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ 
-237,7 +237,7 @@ public void testProviderForServer_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -294,14 +294,14 @@ public void testProviderForServer_tls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( CommonCertProviderTestUtils.getPrivateKey(SERVER_0_KEY_FILE), ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStoreTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStoreTest.java index 8f77de7b5e2..c0bc095eab6 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStoreTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStoreTest.java @@ -123,7 +123,6 @@ public void notifyCertUpdatesNotSupported_expectExceptionOnSecondCall() { } @Test - @SuppressWarnings("deprecation") public void onePluginSameConfig_sameInstance() { registerPlugin("plugin1"); CertificateProvider.Watcher mockWatcher1 = 
mock(CertificateProvider.Watcher.class); @@ -167,7 +166,6 @@ public void onePluginSameConfig_sameInstance() { } @Test - @SuppressWarnings("deprecation") public void onePluginSameConfig_secondWatcherAfterFirstNotify() { registerPlugin("plugin1"); CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class); @@ -275,7 +273,6 @@ public void twoPlugins_differentInstance() { mockWatcher1, handle1, certProviderProvider1, mockWatcher2, handle2, certProviderProvider2); } - @SuppressWarnings("deprecation") private static void checkDifferentInstances( CertificateProvider.Watcher mockWatcher1, CertificateProviderStore.Handle handle1, diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java index a0bdd618004..304a2dd5441 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java @@ -24,22 +24,28 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import io.grpc.internal.JsonParser; import io.grpc.internal.TimeProvider; import java.io.IOException; +import java.util.Collection; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; +import org.junit.After; +import org.junit.Assume; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; /** Unit tests for {@link 
FileWatcherCertificateProviderProvider}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class FileWatcherCertificateProviderProviderTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @@ -48,13 +54,28 @@ public class FileWatcherCertificateProviderProviderTest { scheduledExecutorServiceFactory; @Mock private TimeProvider timeProvider; + @Parameter + public boolean enableSpiffe; + private boolean originalEnableSpiffe; private FileWatcherCertificateProviderProvider provider; + @Parameters(name = "enableSpiffe={0}") + public static Collection data() { + return ImmutableList.of(true, false); + } + @Before public void setUp() throws IOException { provider = new FileWatcherCertificateProviderProvider( fileWatcherCertificateProviderFactory, scheduledExecutorServiceFactory, timeProvider); + originalEnableSpiffe = FileWatcherCertificateProviderProvider.enableSpiffe; + FileWatcherCertificateProviderProvider.enableSpiffe = enableSpiffe; + } + + @After + public void restoreEnvironment() { + FileWatcherCertificateProviderProvider.enableSpiffe = originalEnableSpiffe; } @Test @@ -85,6 +106,30 @@ public void createProvider_minimalConfig() throws IOException { eq("/var/run/gke-spiffe/certs/certificates.pem"), eq("/var/run/gke-spiffe/certs/private_key.pem"), eq("/var/run/gke-spiffe/certs/ca_certificates.pem"), + eq(null), + eq(600L), + eq(mockService), + eq(timeProvider)); + } + + @Test + public void createProvider_minimalSpiffeConfig() throws IOException { + Assume.assumeTrue(enableSpiffe); + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(MINIMAL_FILE_WATCHER_WITH_SPIFFE_CONFIG); + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + provider.createCertificateProvider(map, distWatcher, true); + 
verify(fileWatcherCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(true), + eq("/var/run/gke-spiffe/certs/certificates.pem"), + eq("/var/run/gke-spiffe/certs/private_key.pem"), + eq(null), + eq("/var/run/gke-spiffe/certs/spiffe_bundle.json"), eq(600L), eq(mockService), eq(timeProvider)); @@ -106,6 +151,30 @@ public void createProvider_fullConfig() throws IOException { eq("/var/run/gke-spiffe/certs/certificates2.pem"), eq("/var/run/gke-spiffe/certs/private_key3.pem"), eq("/var/run/gke-spiffe/certs/ca_certificates4.pem"), + eq(null), + eq(7890L), + eq(mockService), + eq(timeProvider)); + } + + @Test + public void createProvider_spiffeConfig() throws IOException { + Assume.assumeTrue(enableSpiffe); + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(FULL_FILE_WATCHER_WITH_SPIFFE_CONFIG); + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + provider.createCertificateProvider(map, distWatcher, true); + verify(fileWatcherCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(true), + eq("/var/run/gke-spiffe/certs/certificates2.pem"), + eq("/var/run/gke-spiffe/certs/private_key3.pem"), + eq(null), + eq("/var/run/gke-spiffe/certs/spiffe_bundle.json"), eq(7890L), eq(mockService), eq(timeProvider)); @@ -157,15 +226,18 @@ public void createProvider_missingKey_expectException() throws IOException { @Test public void createProvider_missingRoot_expectException() throws IOException { + String expectedMessage = enableSpiffe ? 
"either 'ca_certificate_file' or " + + "'spiffe_trust_bundle_map_file' is required in the config" + : "'ca_certificate_file' is required in the config"; CertificateProvider.DistributorWatcher distWatcher = new CertificateProvider.DistributorWatcher(); @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_ROOT_CONFIG); + Map map = (Map) JsonParser.parse(MISSING_ROOT_AND_SPIFFE_CONFIG); try { provider.createCertificateProvider(map, distWatcher, true); fail("exception expected"); } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'ca_certificate_file' is required in the config"); + assertThat(npe).hasMessageThat().isEqualTo(expectedMessage); } } @@ -176,6 +248,14 @@ public void createProvider_missingRoot_expectException() throws IOException { + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\"" + " }"; + private static final String MINIMAL_FILE_WATCHER_WITH_SPIFFE_CONFIG = + "{\n" + + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\"," + + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"," + + " \"spiffe_trust_bundle_map_file\":" + + " \"/var/run/gke-spiffe/certs/spiffe_bundle.json\"" + + " }"; + private static final String FULL_FILE_WATCHER_CONFIG = "{\n" + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates2.pem\"," @@ -184,6 +264,16 @@ public void createProvider_missingRoot_expectException() throws IOException { + " \"refresh_interval\": \"7890s\"" + " }"; + private static final String FULL_FILE_WATCHER_WITH_SPIFFE_CONFIG = + "{\n" + + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates2.pem\"," + + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key3.pem\"," + + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates4.pem\"," + + " \"spiffe_trust_bundle_map_file\":" + + " \"/var/run/gke-spiffe/certs/spiffe_bundle.json\"," + + " \"refresh_interval\": \"7890s\"" + + " }"; + private static final 
String MISSING_CERT_CONFIG = "{\n" + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"," @@ -196,7 +286,7 @@ public void createProvider_missingRoot_expectException() throws IOException { + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\"" + " }"; - private static final String MISSING_ROOT_CONFIG = + private static final String MISSING_ROOT_AND_SPIFFE_CONFIG = "{\n" + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\"," + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"" diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java index 210ec056732..620ee0a7ff7 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java @@ -23,6 +23,7 @@ import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SPIFFE_TRUST_MAP_1_FILE; import static java.nio.file.StandardCopyOption.REPLACE_EXISTING; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -47,6 +48,7 @@ import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.concurrent.Delayed; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; @@ -73,6 +75,7 @@ public class FileWatcherCertificateProviderTest { private static final String CERT_FILE = "cert.pem"; private static final String 
KEY_FILE = "key.pem"; private static final String ROOT_FILE = "root.pem"; + private static final String SPIFFE_TRUST_MAP_FILE = "spiffebundle.json"; @Mock private CertificateProvider.Watcher mockWatcher; @Mock private ScheduledExecutorService timeService; @@ -84,28 +87,33 @@ public class FileWatcherCertificateProviderTest { private String certFile; private String keyFile; private String rootFile; + private String spiffeTrustMapFile; private FileWatcherCertificateProvider provider; + private DistributorWatcher watcher; @Before public void setUp() throws IOException { - DistributorWatcher watcher = new DistributorWatcher(); + watcher = new DistributorWatcher(); watcher.addWatcher(mockWatcher); certFile = new File(tempFolder.getRoot(), CERT_FILE).getAbsolutePath(); keyFile = new File(tempFolder.getRoot(), KEY_FILE).getAbsolutePath(); rootFile = new File(tempFolder.getRoot(), ROOT_FILE).getAbsolutePath(); + spiffeTrustMapFile = new File(tempFolder.getRoot(), SPIFFE_TRUST_MAP_FILE).getAbsolutePath(); provider = - new FileWatcherCertificateProvider( - watcher, true, certFile, keyFile, rootFile, 600L, timeService, timeProvider); + new FileWatcherCertificateProvider(watcher, true, certFile, keyFile, rootFile, null, 600L, + timeService, timeProvider); } private void populateTarget( String certFileSource, String keyFileSource, String rootFileSource, + String spiffeTrustMapFileSource, boolean deleteCurCert, boolean deleteCurKey, + boolean deleteCurSpiffeTrustMap, boolean deleteCurRoot) throws IOException { if (deleteCurCert) { @@ -135,6 +143,17 @@ private void populateTarget( Files.setLastModifiedTime( Paths.get(rootFile), FileTime.fromMillis(timeProvider.currentTimeMillis())); } + if (deleteCurSpiffeTrustMap) { + Files.delete(Paths.get(spiffeTrustMapFile)); + } + if (spiffeTrustMapFileSource != null) { + spiffeTrustMapFileSource = CommonTlsContextTestsUtil + .getTempFileNameForResourcesFile(spiffeTrustMapFileSource); + Files.copy(Paths.get(spiffeTrustMapFileSource), + 
Paths.get(spiffeTrustMapFile), REPLACE_EXISTING); + Files.setLastModifiedTime( + Paths.get(spiffeTrustMapFile), FileTime.fromMillis(timeProvider.currentTimeMillis())); + } } @Test @@ -144,9 +163,9 @@ public void getCertificateAndCheckUpdates() throws IOException, CertificateExcep doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); provider.checkAndReloadCertificates(); - verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE); + verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE, null); verifyTimeServiceAndScheduledFuture(); reset(mockWatcher, timeService); @@ -165,7 +184,7 @@ public void allUpdateSecondTime() throws IOException, CertificateException, Inte doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); provider.checkAndReloadCertificates(); reset(mockWatcher, timeService); @@ -173,9 +192,10 @@ public void allUpdateSecondTime() throws IOException, CertificateException, Inte .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); timeProvider.forwardTime(1, TimeUnit.SECONDS); - populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false); + populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, null, false, false, + false, false); provider.checkAndReloadCertificates(); - verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE); + verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE, null); verifyTimeServiceAndScheduledFuture(); } @@ -186,12 +206,13 @@ public void closeDoesNotScheduleNext() 
throws IOException, CertificateException doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); provider.close(); provider.checkAndReloadCertificates(); verify(mockWatcher, never()) .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); + verify(mockWatcher, never()).updateSpiffeTrustMap(ArgumentMatchers.anyMap()); verify(timeService, never()).schedule(any(Runnable.class), any(Long.TYPE), any(TimeUnit.class)); verify(timeService, times(1)).shutdownNow(); } @@ -204,7 +225,7 @@ public void rootFileUpdateOnly() throws IOException, CertificateException, Inter doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); provider.checkAndReloadCertificates(); reset(mockWatcher, timeService); @@ -212,9 +233,9 @@ public void rootFileUpdateOnly() throws IOException, CertificateException, Inter .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); timeProvider.forwardTime(1, TimeUnit.SECONDS); - populateTarget(null, null, SERVER_1_PEM_FILE, false, false, false); + populateTarget(null, null, SERVER_1_PEM_FILE, null, false, false, false, false); provider.checkAndReloadCertificates(); - verifyWatcherUpdates(null, SERVER_1_PEM_FILE); + verifyWatcherUpdates(null, SERVER_1_PEM_FILE, null); verifyTimeServiceAndScheduledFuture(); } @@ -226,7 +247,7 @@ public void certAndKeyFileUpdateOnly() doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), 
eq(TimeUnit.SECONDS)); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); provider.checkAndReloadCertificates(); reset(mockWatcher, timeService); @@ -234,9 +255,44 @@ public void certAndKeyFileUpdateOnly() .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); timeProvider.forwardTime(1, TimeUnit.SECONDS); - populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, null, false, false, false); + populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, null, null, false, false, false, false); provider.checkAndReloadCertificates(); - verifyWatcherUpdates(SERVER_0_PEM_FILE, null); + verifyWatcherUpdates(SERVER_0_PEM_FILE, null, null); + verifyTimeServiceAndScheduledFuture(); + } + + @Test + public void spiffeTrustMapFileUpdateOnly() throws Exception { + provider = new FileWatcherCertificateProvider(watcher, true, certFile, keyFile, null, + spiffeTrustMapFile, 600L, timeService, timeProvider); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, null, false, false, false, false); + provider.checkAndReloadCertificates(); + verify(mockWatcher, never()).updateSpiffeTrustMap(ArgumentMatchers.anyMap()); + + reset(timeService); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + timeProvider.forwardTime(1, TimeUnit.SECONDS); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, SPIFFE_TRUST_MAP_FILE, false, + false, false, false); + provider.checkAndReloadCertificates(); + verify(mockWatcher, times(1)).updateSpiffeTrustMap(ArgumentMatchers.anyMap()); + + reset(timeService); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), 
any(Long.TYPE), eq(TimeUnit.SECONDS)); + timeProvider.forwardTime(1, TimeUnit.SECONDS); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, SPIFFE_TRUST_MAP_1_FILE, false, + false, false, false); + provider.checkAndReloadCertificates(); + verify(mockWatcher, times(2)).updateSpiffeTrustMap(ArgumentMatchers.anyMap()); verifyTimeServiceAndScheduledFuture(); } @@ -247,7 +303,7 @@ public void getCertificate_initialMissingCertFile() throws IOException { doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(null, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + populateTarget(null, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 0, 1, "cert.pem"); } @@ -255,13 +311,14 @@ public void getCertificate_initialMissingCertFile() throws IOException { @Test public void getCertificate_missingCertFile() throws IOException, InterruptedException { commonErrorTest( - null, CLIENT_KEY_FILE, CA_PEM_FILE, NoSuchFileException.class, 0, 1, 0, 0, "cert.pem"); + null, CLIENT_KEY_FILE, CA_PEM_FILE, null, NoSuchFileException.class, 0, 1, 0, 0, + "cert.pem"); } @Test public void getCertificate_missingKeyFile() throws IOException, InterruptedException { commonErrorTest( - CLIENT_PEM_FILE, null, CA_PEM_FILE, NoSuchFileException.class, 0, 1, 0, 0, "key.pem"); + CLIENT_PEM_FILE, null, CA_PEM_FILE, null, NoSuchFileException.class, 0, 1, 0, 0, "key.pem"); } @Test @@ -270,6 +327,7 @@ public void getCertificate_badKeyFile() throws IOException, InterruptedException CLIENT_PEM_FILE, SERVER_0_PEM_FILE, CA_PEM_FILE, + null, java.security.spec.InvalidKeySpecException.class, 0, 1, @@ -285,12 +343,13 @@ public void getCertificate_missingRootFile() throws IOException, InterruptedExce doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); 
- populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false); + populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, null, false, false, + false, false); provider.checkAndReloadCertificates(); reset(mockWatcher); timeProvider.forwardTime(1, TimeUnit.SECONDS); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, false, false, true); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, null, false, false, false, true); timeProvider.forwardTime( CERT0_EXPIRY_TIME_MILLIS - 610_000L - timeProvider.currentTimeMillis(), TimeUnit.MILLISECONDS); @@ -302,6 +361,7 @@ private void commonErrorTest( String certFile, String keyFile, String rootFile, + String spiffeFile, Class throwableType, int firstUpdateCertCount, int firstUpdateRootCount, @@ -314,13 +374,15 @@ private void commonErrorTest( doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false); + populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, + SPIFFE_TRUST_MAP_1_FILE, false, false, false, false); provider.checkAndReloadCertificates(); reset(mockWatcher); timeProvider.forwardTime(1, TimeUnit.SECONDS); populateTarget( - certFile, keyFile, rootFile, certFile == null, keyFile == null, rootFile == null); + certFile, keyFile, rootFile, spiffeFile, certFile == null, keyFile == null, + rootFile == null, spiffeFile == null); timeProvider.forwardTime( CERT0_EXPIRY_TIME_MILLIS - 610_000L - timeProvider.currentTimeMillis(), TimeUnit.MILLISECONDS); @@ -372,7 +434,7 @@ private void verifyTimeServiceAndScheduledFuture() { assertThat(provider.scheduledFuture.isCancelled()).isFalse(); } - private void verifyWatcherUpdates(String certPemFile, String rootPemFile) + private void verifyWatcherUpdates(String certPemFile, String rootPemFile, String spiffeFile) throws IOException, CertificateException { if 
(certPemFile != null) { @SuppressWarnings("unchecked") @@ -399,6 +461,17 @@ private void verifyWatcherUpdates(String certPemFile, String rootPemFile) } else { verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); } + if (spiffeFile != null) { + @SuppressWarnings("unchecked") + ArgumentCaptor>> spiffeCaptor = + ArgumentCaptor.forClass(Map.class); + verify(mockWatcher, times(1)).updateSpiffeTrustMap(spiffeCaptor.capture()); + Map> trustMap = spiffeCaptor.getValue(); + assertThat(trustMap).hasSize(2); + verify(mockWatcher, never()).onError(any(Status.class)); + } else { + verify(mockWatcher, never()).updateSpiffeTrustMap(ArgumentMatchers.anyMap()); + } } static class TestScheduledFuture implements ScheduledFuture { diff --git a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java index 77749814cf2..3077482b10b 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java @@ -23,6 +23,8 @@ import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.protobuf.ByteString; import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; @@ -89,7 +91,7 @@ public void constructor_fromRootCert() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, 
false); assertThat(factory).isNotNull(); TrustManager[] tms = factory.getTrustManagers(); assertThat(tms).isNotNull(); @@ -105,6 +107,46 @@ public void constructor_fromRootCert() .isEqualTo(CertificateUtils.toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))[0]); } + @Test + public void constructor_fromSpiffeTrustMap() + throws CertificateException, IOException, CertStoreException { + X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); + CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", + "san2"); + // Single domain and single cert + XdsTrustManagerFactory factory = new XdsTrustManagerFactory(ImmutableMap + .of("example.com", ImmutableList.of(x509Cert)), staticValidationContext, false); + assertThat(factory).isNotNull(); + TrustManager[] tms = factory.getTrustManagers(); + assertThat(tms).isNotNull(); + assertThat(tms).hasLength(1); + TrustManager myTm = tms[0]; + assertThat(myTm).isInstanceOf(XdsX509TrustManager.class); + XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) myTm; + assertThat(xdsX509TrustManager.getAcceptedIssuers()).isNotNull(); + assertThat(xdsX509TrustManager.getAcceptedIssuers()).hasLength(1); + assertThat(xdsX509TrustManager.getAcceptedIssuers()[0].getIssuerX500Principal().getName()) + .isEqualTo("CN=testca,O=Internet Widgits Pty Ltd,ST=Some-State,C=AU"); + // Multiple domains and multiple certs for one of it + X509Certificate anotherCert = TestUtils.loadX509Cert(CLIENT_PEM_FILE); + factory = new XdsTrustManagerFactory(ImmutableMap + .of("example.com", ImmutableList.of(x509Cert), + "google.com", ImmutableList.of(x509Cert, anotherCert)), staticValidationContext, false); + assertThat(factory).isNotNull(); + tms = factory.getTrustManagers(); + assertThat(tms).isNotNull(); + assertThat(tms).hasLength(1); + myTm = tms[0]; + assertThat(myTm).isInstanceOf(XdsX509TrustManager.class); + xdsX509TrustManager = (XdsX509TrustManager) myTm; + 
assertThat(xdsX509TrustManager.getAcceptedIssuers()).isNotNull(); + assertThat(xdsX509TrustManager.getAcceptedIssuers()).hasLength(2); + assertThat(xdsX509TrustManager.getAcceptedIssuers()[0].getIssuerX500Principal().getName()) + .isEqualTo("CN=testca,O=Internet Widgits Pty Ltd,ST=Some-State,C=AU"); + assertThat(xdsX509TrustManager.getAcceptedIssuers()[1].getIssuerX500Principal().getName()) + .isEqualTo("CN=testca,O=Internet Widgits Pty Ltd,ST=Some-State,C=AU"); + } + @Test public void constructorRootCert_checkServerTrusted() throws CertificateException, IOException, CertStoreException { @@ -112,7 +154,7 @@ public void constructorRootCert_checkServerTrusted() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "waterzooi.test.google.be"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, false); XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] serverChain = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); @@ -125,7 +167,7 @@ public void constructorRootCert_nonStaticContext_throwsException() X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); try { new XdsTrustManagerFactory( - new X509Certificate[] {x509Cert}, getCertContextFromPath(CA_PEM_FILE)); + new X509Certificate[] {x509Cert}, getCertContextFromPath(CA_PEM_FILE), false); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected) @@ -134,6 +176,19 @@ public void constructorRootCert_nonStaticContext_throwsException() } } + @Test + public void constructorRootCert_nonStaticContext_systemRootCerts_valid() + throws CertificateException, IOException, CertStoreException { + X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); + CertificateValidationContext 
certValidationContext = CertificateValidationContext.newBuilder() + .setTrustedCa( + DataSource.newBuilder().setFilename(TestUtils.loadCert(CA_PEM_FILE).getAbsolutePath())) + .setSystemRootCerts(CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(); + XdsTrustManagerFactory unused = + new XdsTrustManagerFactory(new X509Certificate[] {x509Cert}, certValidationContext, false); + } + @Test public void constructorRootCert_checkServerTrusted_throwsException() throws CertificateException, IOException, CertStoreException { @@ -141,7 +196,7 @@ public void constructorRootCert_checkServerTrusted_throwsException() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, false); XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] serverChain = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); @@ -162,7 +217,7 @@ public void constructorRootCert_checkClientTrusted_throwsException() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, false); XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] clientChain = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java index 9ceb6f706fe..ffe0536f25b 100644 --- 
a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java @@ -18,9 +18,13 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.BAD_WILDCARD_DNS_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_SPIFFE_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_SPIFFE_PEM_FILE; import static org.junit.Assert.fail; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.doReturn; @@ -30,6 +34,7 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; @@ -38,6 +43,9 @@ import java.security.cert.CertStoreException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.List; import javax.net.ssl.SSLEngine; @@ -48,7 +56,8 @@ import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import 
org.junit.runners.Parameterized.Parameters; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -56,7 +65,7 @@ /** * Unit tests for {@link XdsX509TrustManager}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class XdsX509TrustManagerTest { @Rule @@ -70,32 +79,40 @@ public class XdsX509TrustManagerTest { private XdsX509TrustManager trustManager; + private final TestParam testParam; + + public XdsX509TrustManagerTest(TestParam testParam) { + this.testParam = testParam; + } + @Test public void nullCertContextTest() throws CertificateException, IOException { - trustManager = new XdsX509TrustManager(null, mockDelegate); + trustManager = new XdsX509TrustManager(null, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, new ArrayList<>()); } @Test + @SuppressWarnings("deprecation") public void emptySanListContextTest() throws CertificateException, IOException { CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void missingPeerCerts() { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + 
trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); try { - trustManager.verifySubjectAltNameInChain(null); + trustManager.verifySubjectAltNameInChain(null, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate(s) missing"); @@ -103,14 +120,15 @@ public void missingPeerCerts() { } @Test + @SuppressWarnings("deprecation") public void emptyArrayPeerCerts() { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); try { - trustManager.verifySubjectAltNameInChain(new X509Certificate[0]); + trustManager.verifySubjectAltNameInChain( + new X509Certificate[0], certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate(s) missing"); @@ -118,16 +136,16 @@ public void emptyArrayPeerCerts() { } @Test + @SuppressWarnings("deprecation") public void noSansInPeerCerts() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(CLIENT_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + 
trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -135,22 +153,23 @@ public void noSansInPeerCerts() throws CertificateException, IOException { } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsVerifies() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setExact("waterzooi.test.google.be") .setIgnoreCase(false) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsVerifies_differentCase_expectException() throws CertificateException, IOException { StringMatcher stringMatcher = @@ -158,14 +177,13 @@ public void oneSanInPeerCertsVerifies_differentCase_expectException() .setExact("waterZooi.test.Google.be") .setIgnoreCase(false) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + 
trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -173,47 +191,48 @@ public void oneSanInPeerCertsVerifies_differentCase_expectException() } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsVerifies_ignoreCase() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("Waterzooi.Test.google.be").setIgnoreCase(true).build(); @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_prefix() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setPrefix("waterzooi.") // test.google.be .setIgnoreCase(false) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + 
@SuppressWarnings("deprecation") public void oneSanInPeerCertsPrefix_differentCase_expectException() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setPrefix("waterZooi.").setIgnoreCase(false).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -221,47 +240,47 @@ public void oneSanInPeerCertsPrefix_differentCase_expectException() } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_prefixIgnoreCase() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setPrefix("WaterZooi.") // test.google.be .setIgnoreCase(true) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_suffix() throws 
CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setSuffix(".google.be").setIgnoreCase(false).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsSuffix_differentCase_expectException() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setSuffix(".gooGle.bE").setIgnoreCase(false).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -269,44 +288,45 @@ public void oneSanInPeerCertsSuffix_differentCase_expectException() } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_suffixIgnoreCase() throws CertificateException, IOException { StringMatcher stringMatcher = 
StringMatcher.newBuilder().setSuffix(".GooGle.BE").setIgnoreCase(true).build(); @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_substring() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setContains("zooi.test.google").setIgnoreCase(false).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsSubstring_differentCase_expectException() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setContains("zooi.Test.gooGle").setIgnoreCase(false).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new 
XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -314,81 +334,81 @@ public void oneSanInPeerCertsSubstring_differentCase_expectException() } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_substringIgnoreCase() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setContains("zooI.Test.Google").setIgnoreCase(true).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_safeRegex() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setSafeRegex( RegexMatcher.newBuilder().setRegex("water[[:alpha:]]{1}ooi\\.test\\.google\\.be")) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] 
certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_safeRegex1() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setSafeRegex( RegexMatcher.newBuilder().setRegex("no-match-string|\\*\\.test\\.youtube\\.com")) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_safeRegex_ipAddress() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setSafeRegex( RegexMatcher.newBuilder().setRegex("([[:digit:]]{1,3}\\.){3}[[:digit:]]{1,3}")) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void 
oneSanInPeerCerts_safeRegex_noMatch() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setSafeRegex( RegexMatcher.newBuilder().setRegex("water[[:alpha:]]{2}ooi\\.test\\.google\\.be")) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -396,35 +416,35 @@ public void oneSanInPeerCerts_safeRegex_noMatch() throws CertificateException, I } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsVerifiesMultipleVerifySans() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setExact("waterzooi.test.google.be").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) .build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, 
certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsNotFoundException() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -432,42 +452,43 @@ public void oneSanInPeerCertsNotFoundException() } @Test + @SuppressWarnings("deprecation") public void wildcardSanInPeerCertsVerifiesMultipleVerifySans() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setSuffix("test.youTube.Com").setIgnoreCase(true).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) // should match suffix test.youTube.Com .build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + 
trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void wildcardSanInPeerCertsVerifiesMultipleVerifySans1() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setContains("est.Google.f").setIgnoreCase(true).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) // should contain est.Google.f .build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void wildcardSanInPeerCertsSubdomainMismatch() throws CertificateException, IOException { // 2. Asterisk (*) cannot match across domain name labels. @@ -475,14 +496,13 @@ public void wildcardSanInPeerCertsSubdomainMismatch() // sub.test.example.com. 
StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("sub.abc.test.youtube.com").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -490,36 +510,36 @@ public void wildcardSanInPeerCertsSubdomainMismatch() } @Test + @SuppressWarnings("deprecation") public void oneIpAddressInPeerCertsVerifies() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setExact("192.168.1.3").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) .build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneIpAddressInPeerCertsMismatch() throws CertificateException, IOException { StringMatcher stringMatcher = 
StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setExact("192.168.2.3").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) .build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -537,6 +557,71 @@ public void checkServerTrustedSslEngine() assertThat(sslEngine.getSSLParameters().getEndpointIdentificationAlgorithm()).isEmpty(); } + @Test + public void checkServerTrustedSslEngineSpiffeTrustMap() + throws CertificateException, IOException, CertStoreException { + TestSslEngine sslEngine = buildTrustManagerAndGetSslEngine(); + X509Certificate[] serverCerts = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_SPIFFE_PEM_FILE)); + List caCerts = Arrays.asList(CertificateUtils + .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); + trustManager = XdsTrustManagerFactory.createX509TrustManager( + ImmutableMap.of("example.com", caCerts), null, false); + trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); + verify(sslEngine, times(1)).getHandshakeSession(); + assertThat(sslEngine.getSSLParameters().getEndpointIdentificationAlgorithm()).isEmpty(); + } + + @Test + public void checkServerTrustedSslEngineSpiffeTrustMap_missing_spiffe_id() + throws CertificateException, IOException, CertStoreException { + 
TestSslEngine sslEngine = buildTrustManagerAndGetSslEngine(); + X509Certificate[] serverCerts = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + List caCerts = Arrays.asList(CertificateUtils + .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); + trustManager = XdsTrustManagerFactory.createX509TrustManager( + ImmutableMap.of("example.com", caCerts), null, false); + try { + trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); + fail("exception expected"); + } catch (CertificateException expected) { + assertThat(expected).hasMessageThat() + .isEqualTo("Failed to extract SPIFFE ID from peer leaf certificate"); + } + } + + @Test + public void checkServerTrustedSpiffeSslEngineTrustMap_missing_trust_domain() + throws CertificateException, IOException, CertStoreException { + TestSslEngine sslEngine = buildTrustManagerAndGetSslEngine(); + X509Certificate[] serverCerts = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_SPIFFE_PEM_FILE)); + List caCerts = Arrays.asList(CertificateUtils + .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); + trustManager = XdsTrustManagerFactory.createX509TrustManager( + ImmutableMap.of("unknown.com", caCerts), null, false); + try { + trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); + fail("exception expected"); + } catch (CertificateException expected) { + assertThat(expected).hasMessageThat().isEqualTo("Spiffe Trust Map doesn't contain trust" + + " domain 'example.com' from peer leaf certificate"); + } + } + + @Test + public void checkClientTrustedSpiffeTrustMap() + throws CertificateException, IOException, CertStoreException { + X509Certificate[] clientCerts = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(CLIENT_SPIFFE_PEM_FILE)); + List caCerts = Arrays.asList(CertificateUtils + .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); + trustManager = XdsTrustManagerFactory.createX509TrustManager( + ImmutableMap.of("foo.bar.com", 
caCerts), null, false); + trustManager.checkClientTrusted(clientCerts, "RSA"); + } + @Test public void checkServerTrustedSslEngine_untrustedServer_expectException() throws CertificateException, IOException, CertStoreException { @@ -565,6 +650,22 @@ public void checkServerTrustedSslSocket() assertThat(sslSocket.getSSLParameters().getEndpointIdentificationAlgorithm()).isEmpty(); } + @Test + public void checkServerTrustedSslSocketSpiffeTrustMap() + throws CertificateException, IOException, CertStoreException { + TestSslSocket sslSocket = buildTrustManagerAndGetSslSocket(); + X509Certificate[] serverCerts = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_SPIFFE_PEM_FILE)); + List caCerts = Arrays.asList(CertificateUtils + .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); + trustManager = XdsTrustManagerFactory.createX509TrustManager( + ImmutableMap.of("example.com", caCerts), null, false); + trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslSocket); + verify(sslSocket, times(1)).isConnected(); + verify(sslSocket, times(1)).getHandshakeSession(); + assertThat(sslSocket.getSSLParameters().getEndpointIdentificationAlgorithm()).isEmpty(); + } + @Test public void checkServerTrustedSslSocket_untrustedServer_expectException() throws CertificateException, IOException, CertStoreException { @@ -583,29 +684,76 @@ public void checkServerTrustedSslSocket_untrustedServer_expectException() } @Test - public void unsupportedAltNameType() throws CertificateException, IOException { + @SuppressWarnings("deprecation") + public void unsupportedAltNameType() throws CertificateException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setExact("waterzooi.test.google.be") .setIgnoreCase(false) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + 
trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate mockCert = mock(X509Certificate.class); when(mockCert.getSubjectAlternativeNames()) .thenReturn(Collections.>singleton(ImmutableList.of(Integer.valueOf(1), "foo"))); X509Certificate[] certs = new X509Certificate[] {mockCert}; try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); } } + @Test + @SuppressWarnings("deprecation") + public void testDnsWildcardPatterns() + throws CertificateException, IOException { + StringMatcher stringMatcher = + StringMatcher.newBuilder() + .setExact(testParam.sanPattern) + .setIgnoreCase(testParam.ignoreCase) + .build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder() + .addMatchSubjectAltNames(stringMatcher) + .build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); + X509Certificate[] certs = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(testParam.certFile)); + try { + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); + assertThat(testParam.expected).isTrue(); + } catch (CertificateException certException) { + assertThat(testParam.expected).isFalse(); + assertThat(certException).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); + } + } + + @Parameters(name = "{index}: {0}") + public static Collection getParameters() { + return Arrays.asList(new Object[][] { + {new TestParam("*.test.google.fr", SERVER_1_PEM_FILE, false, true)}, + {new TestParam("*.test.youtube.com", SERVER_1_PEM_FILE, false, true)}, + {new TestParam("waterzooi.test.google.be", SERVER_1_PEM_FILE, false, true)}, + {new TestParam("192.168.1.3", 
SERVER_1_PEM_FILE, false, true)}, + {new TestParam("*.TEST.YOUTUBE.com", SERVER_1_PEM_FILE, true, true)}, + {new TestParam("w*i.test.google.be", SERVER_1_PEM_FILE, false, true)}, + {new TestParam("w*a.test.google.be", SERVER_1_PEM_FILE, false, false)}, + {new TestParam("*.test.google.com.au", SERVER_0_PEM_FILE, false, false)}, + {new TestParam("*.TEST.YOUTUBE.com", SERVER_1_PEM_FILE, false, false)}, + {new TestParam("*waterzooi", SERVER_1_PEM_FILE, false, false)}, + {new TestParam("*.lyft.com", BAD_WILDCARD_DNS_PEM_FILE, false, false)}, + {new TestParam("ly**ft.com", BAD_WILDCARD_DNS_PEM_FILE, false, false)}, + {new TestParam("*yft.c*m", BAD_WILDCARD_DNS_PEM_FILE, false, false)}, + {new TestParam("xn--*.lyft.com", BAD_WILDCARD_DNS_PEM_FILE, false, false)}, + {new TestParam("", BAD_WILDCARD_DNS_PEM_FILE, false, false)}, + }); + } + private TestSslEngine buildTrustManagerAndGetSslEngine() throws CertificateException, IOException, CertStoreException { SSLParameters sslParams = buildTrustManagerAndGetSslParameters(); @@ -632,7 +780,7 @@ private SSLParameters buildTrustManagerAndGetSslParameters() X509Certificate[] caCerts = CertificateUtils.toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE)); trustManager = XdsTrustManagerFactory.createX509TrustManager(caCerts, - null); + null, false); when(mockSession.getProtocol()).thenReturn("TLSv1.2"); when(mockSession.getPeerHost()).thenReturn("peer-host-from-mock"); SSLParameters sslParams = new SSLParameters(); @@ -669,4 +817,18 @@ public void setSSLParameters(SSLParameters sslParameters) { private SSLParameters sslParameters; } + + private static class TestParam { + final String sanPattern; + final String certFile; + final boolean ignoreCase; + final boolean expected; + + TestParam(String sanPattern, String certFile, boolean ignoreCase, boolean expected) { + this.sanPattern = sanPattern; + this.certFile = certFile; + this.ignoreCase = ignoreCase; + this.expected = expected; + } + } } diff --git 
a/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java b/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java new file mode 100644 index 00000000000..db9168dd08e --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.orca; + +import io.grpc.LoadBalancer; + +/** + * Accessor for white-box testing involving OrcaOobUtil. + */ +public final class OrcaOobUtilAccessor { + private OrcaOobUtilAccessor() { + // Do not instantiate + } + + public static LoadBalancer.SubchannelPicker getDelegate(LoadBalancer.SubchannelPicker picker) { + if (picker instanceof OrcaOobUtil.OrcaReportingHelper.OrcaOobPicker) { + return ((OrcaOobUtil.OrcaReportingHelper.OrcaOobPicker) picker).delegate; + } + return picker; + } +} diff --git a/xds/src/test/resources/certs/sni-test-certs/README b/xds/src/test/resources/certs/sni-test-certs/README new file mode 100644 index 00000000000..25e66021192 --- /dev/null +++ b/xds/src/test/resources/certs/sni-test-certs/README @@ -0,0 +1,55 @@ +Bad Wildcard DNS Certificate (bad_wildcard_dns_certificate.pem) +This certificate is used for testing SNI with invalid wildcard DNS SANs. It is issued by a custom, self-signed Certificate Authority (CA). + +1. 
Create the Certificate Authority (CA) +Create the CA's private key: +$ openssl genpkey -algorithm RSA -out ca.key -pkeyopt rsa_keygen_bits:2048 +Create the CA's self-signed certificate: +$ openssl req -x509 -new -nodes -key ca.key -sha256 -days 365 -out ca.pem -subj "/CN=My Internal CA" + +2. Generate the Server Certificate +Next, generate the server's private key and a Certificate Signing Request (CSR). +Create the server's private key: +$ openssl genpkey -algorithm RSA -out bad_wildcard_dns.key -pkeyopt rsa_keygen_bits:2048 +Create a configuration file named san.cnf with the following content. This file specifies the Subject Alternative Names (SANs) for the certificate. +[req] +distinguished_name = req_distinguished_name +req_extensions = v3_req +prompt = no + +[req_distinguished_name] +C = US +ST = Illinois +L = Chicago +O = "Example, Co." +CN = *.test.google.com + +[v3_req] +keyUsage = nonRepudiation, digitalSignature, keyEncipherment +extendedKeyUsage = serverAuth +subjectAltName = @alt_names + +[alt_names] +DNS.1 = *.test.google.fr +DNS.2 = *.test.youtube.com +DNS.3 = waterzooi.test.google.be +DNS.4 = 192.168.1.3 +DNS.5 = *.TEST.YOUTUBE.com +DNS.6 = w*i.test.google.be +DNS.7 = w*a.test.google.be +DNS.8 = *.test.google.com.au +DNS.9 = *waterzooi +DNS.10 = *.lyft.com +DNS.11 = ly**ft.com +DNS.12 = *yft.c*m +DNS.13 = xn--*.lyft.com + +Create the Certificate Signing Request (CSR): +$ openssl req -new -key bad_wildcard_dns.key -out bad_wildcard_dns.csr -config san.cnf + +3. Sign the Server Certificate +Finally, use the CA to sign the CSR, which will create the server certificate. +$ openssl x509 -req -in bad_wildcard_dns.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out bad_wildcard_dns_certificate.pem -days 365 -sha256 -extensions v3_req -extfile san.cnf + +4. 
Clean Up +$ rm bad_wildcard_dns.key san.cnf bad_wildcard_dns.csr ca.key ca.pem ca.srl diff --git a/xds/src/test/resources/certs/sni-test-certs/bad_wildcard_dns_certificate.pem b/xds/src/test/resources/certs/sni-test-certs/bad_wildcard_dns_certificate.pem new file mode 100644 index 00000000000..b015f62e51c --- /dev/null +++ b/xds/src/test/resources/certs/sni-test-certs/bad_wildcard_dns_certificate.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDsjCCApqgAwIBAgIUCs5j4C2KXgCRVFa48kc5TYRS1JwwDQYJKoZIhvcNAQEL +BQAwGTEXMBUGA1UEAwwOTXkgSW50ZXJuYWwgQ0EwIBcNMjUwOTIzMDc1NDUzWhgP +MjEyNTA4MzAwNzU0NTNaMGUxCzAJBgNVBAYTAlVTMREwDwYDVQQIDAhJbGxpbm9p +czEQMA4GA1UEBwwHQ2hpY2FnbzEVMBMGA1UECgwMRXhhbXBsZSwgQ28uMRowGAYD +VQQDDBEqLnRlc3QuZ29vZ2xlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC +AQoCggEBAKoqcnNh9MV39GH6JjC5KVMN6MO1IoTw6wHJN0JJ/nGNx6ycIsBK8SgJ +eYRR2BEpT6WZba+f04KChcB4Z9tiPISNvUBpmEv76rAsdtcAZwSpF06q4wxHVE5F +rX6mNT8hk448mDBDGHUXNAT6g/e/Vlt6U0XRyuu713gbZq1X6JH29FG7EJ3LUx35 +h6sEkvTlZZ3m6NJr7zYoqrYh/gRkPigtPxaNcoXo0gVm4IEde0sYz27SWyNH4v/o +23NynSulOwx4DwEhBOXekLb5QJHBqwMTPynaMncBQIXF+PXeuxN9a3zR6DSn+jGw +g008tS0tn2FuAvJDBl0paEykdOr2rNMCAwEAAaOBozCBoDALBgNVHQ8EBAMCBeAw +EwYDVR0lBAwwCgYIKwYBBQUHAwEwPAYDVR0RBDUwM4IKbHkqKmZ0LmNvbYIIKnlm +dC5jKm2CCS5seWZ0LmNvbYIOeG4tLSoubHlmdC5jb22CADAdBgNVHQ4EFgQUZoL2 +OzBtK/BUzSYfgXDx3iDjcIQwHwYDVR0jBBgwFoAUHlstFN5WSLSqyJgUDy6BB0z0 +BrgwDQYJKoZIhvcNAQELBQADggEBAMYwVOT7XZJMQ6n32pQqtZhJ/Z0wlVfCAbm0 +7xospeBt6KtOz2zIsvPpq0aqPjowMAeL1EZaBvmfm/XgWUU5e/3hLUIHOHyKfswB +czDbY0RE8nfVDoF4Ck1ljPjvrFr4tSAxTzVA4JU5o3UXkblBg0LG6tTuLlZ3x5aF +KtkZnszxjE+vOg6J9MDbFP/xtA1oVHyCvk+cUgnBxAoPShI+87DINGVTmztBSetK +nJN9dOh7Q88NhTLHOe67Ora9Y0ZP+uFKHaqFv8qj8B/Q6ptb0CAksdL5EunkIHrq +glKdVdYgIP2JpRwtvVHK5FzWBlGXCi3DxTyYi6FWqsSJ+heCS2w= +-----END CERTIFICATE----- diff --git a/xds/third_party/envoy/import.sh b/xds/third_party/envoy/import.sh index 41506c2ed32..74b8af750ab 100755 --- a/xds/third_party/envoy/import.sh +++ b/xds/third_party/envoy/import.sh @@ -16,8 +16,8 @@ # 
Update VERSION then execute this script set -e -# import VERSION from the google internal copybara_version.txt for Envoy -VERSION=ab911ac2ff971f805ec822ad4d4ff6b42a61cc7c +# import VERSION from the google internal go/envoy-import-status +VERSION=a0b3df32ba54c92a08d3636a9a36013cb920e471 DOWNLOAD_URL="https://github.com/envoyproxy/envoy/archive/${VERSION}.tar.gz" DOWNLOAD_BASE_DIR="envoy-${VERSION}" SOURCE_PROTO_BASE_DIR="${DOWNLOAD_BASE_DIR}/api" @@ -33,9 +33,11 @@ envoy/config/cluster/v3/circuit_breaker.proto envoy/config/cluster/v3/cluster.proto envoy/config/cluster/v3/filter.proto envoy/config/cluster/v3/outlier_detection.proto +envoy/config/common/mutation_rules/v3/mutation_rules.proto envoy/config/core/v3/address.proto envoy/config/core/v3/backoff.proto envoy/config/core/v3/base.proto +envoy/config/core/v3/cel.proto envoy/config/core/v3/config_source.proto envoy/config/core/v3/event_service_config.proto envoy/config/core/v3/extension.proto @@ -46,6 +48,7 @@ envoy/config/core/v3/http_uri.proto envoy/config/core/v3/protocol.proto envoy/config/core/v3/proxy_protocol.proto envoy/config/core/v3/resolver.proto +envoy/config/core/v3/socket_cmsg_headers.proto envoy/config/core/v3/socket_option.proto envoy/config/core/v3/substitution_format_string.proto envoy/config/core/v3/udp_socket_config.proto @@ -73,11 +76,21 @@ envoy/config/trace/v3/zipkin.proto envoy/data/accesslog/v3/accesslog.proto envoy/extensions/clusters/aggregate/v3/cluster.proto envoy/extensions/filters/common/fault/v3/fault.proto +envoy/extensions/filters/http/ext_authz/v3/ext_authz.proto +envoy/extensions/common/matching/v3/extension_matcher.proto envoy/extensions/filters/http/fault/v3/fault.proto +envoy/extensions/filters/http/composite/v3/composite.proto envoy/extensions/filters/http/rate_limit_quota/v3/rate_limit_quota.proto +envoy/extensions/filters/http/gcp_authn/v3/gcp_authn.proto envoy/extensions/filters/http/rbac/v3/rbac.proto envoy/extensions/filters/http/router/v3/router.proto 
envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto +envoy/extensions/grpc_service/call_credentials/access_token/v3/access_token_credentials.proto +envoy/extensions/grpc_service/channel_credentials/google_default/v3/google_default_credentials.proto +envoy/extensions/grpc_service/channel_credentials/insecure/v3/insecure_credentials.proto +envoy/extensions/grpc_service/channel_credentials/local/v3/local_credentials.proto +envoy/extensions/grpc_service/channel_credentials/tls/v3/tls_credentials.proto +envoy/extensions/grpc_service/channel_credentials/xds/v3/xds_credentials.proto envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3/client_side_weighted_round_robin.proto envoy/extensions/load_balancing_policies/common/v3/common.proto envoy/extensions/load_balancing_policies/least_request/v3/least_request.proto @@ -85,20 +98,25 @@ envoy/extensions/load_balancing_policies/pick_first/v3/pick_first.proto envoy/extensions/load_balancing_policies/ring_hash/v3/ring_hash.proto envoy/extensions/load_balancing_policies/round_robin/v3/round_robin.proto envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto +envoy/extensions/transport_sockets/http_11_proxy/v3/upstream_http_11_connect.proto envoy/extensions/transport_sockets/tls/v3/cert.proto envoy/extensions/transport_sockets/tls/v3/common.proto envoy/extensions/transport_sockets/tls/v3/secret.proto envoy/extensions/transport_sockets/tls/v3/tls.proto +envoy/service/auth/v3/attribute_context.proto +envoy/service/auth/v3/external_auth.proto envoy/service/discovery/v3/ads.proto envoy/service/discovery/v3/discovery.proto envoy/service/load_stats/v3/lrs.proto envoy/service/rate_limit_quota/v3/rlqs.proto envoy/service/status/v3/csds.proto envoy/type/http/v3/path_transformation.proto +envoy/type/matcher/v3/address.proto envoy/type/matcher/v3/filter_state.proto envoy/type/matcher/v3/http_inputs.proto envoy/type/matcher/v3/metadata.proto 
envoy/type/matcher/v3/node.proto +envoy/config/common/matcher/v3/matcher.proto envoy/type/matcher/v3/number.proto envoy/type/matcher/v3/path.proto envoy/type/matcher/v3/regex.proto diff --git a/xds/third_party/envoy/src/main/proto/envoy/admin/v3/config_dump_shared.proto b/xds/third_party/envoy/src/main/proto/envoy/admin/v3/config_dump_shared.proto index 8de77e18e1f..b34e004d986 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/admin/v3/config_dump_shared.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/admin/v3/config_dump_shared.proto @@ -39,6 +39,14 @@ enum ClientResourceStatus { // Client received this resource and replied with NACK. NACKED = 4; + + // Client received an error from the control plane. The attached config + // dump is the most recent accepted one. If no config is accepted yet, + // the attached config dump will be empty. + RECEIVED_ERROR = 5; + + // Client timed out waiting for the resource from the control plane. + TIMEOUT = 6; } message UpdateFailureState { diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto b/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto index 5599f8082d3..f273f2e695f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto @@ -108,6 +108,9 @@ message ComparisonFilter { // <= LE = 2; + + // != + NE = 3; } // Comparison operator. @@ -152,35 +155,38 @@ message TraceableFilter { "envoy.config.filter.accesslog.v2.TraceableFilter"; } -// Filters for random sampling of requests. +// Filters requests based on runtime-configurable sampling rates. message RuntimeFilter { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.accesslog.v2.RuntimeFilter"; - // Runtime key to get an optional overridden numerator for use in the - // ``percent_sampled`` field. 
If found in runtime, this value will replace the - // default numerator. + // Specifies a key used to look up a custom sampling rate from the runtime configuration. If a value is found for this + // key, it will override the default sampling rate specified in ``percent_sampled``. string runtime_key = 1 [(validate.rules).string = {min_len: 1}]; - // The default sampling percentage. If not specified, defaults to 0% with - // denominator of 100. + // Defines the default sampling percentage when no runtime override is present. If not specified, the default is + // **0%** (with a denominator of 100). type.v3.FractionalPercent percent_sampled = 2; - // By default, sampling pivots on the header - // :ref:`x-request-id` being - // present. If :ref:`x-request-id` - // is present, the filter will consistently sample across multiple hosts based - // on the runtime key value and the value extracted from - // :ref:`x-request-id`. If it is - // missing, or ``use_independent_randomness`` is set to true, the filter will - // randomly sample based on the runtime key value alone. - // ``use_independent_randomness`` can be used for logging kill switches within - // complex nested :ref:`AndFilter - // ` and :ref:`OrFilter - // ` blocks that are easier to - // reason about from a probability perspective (i.e., setting to true will - // cause the filter to behave like an independent random variable when - // composed within logical operator filters). + // Controls how sampling decisions are made. + // + // - Default behavior (``false``): + // + // * Uses the :ref:`x-request-id` as a consistent sampling pivot. + // * When :ref:`x-request-id` is present, sampling will be consistent + // across multiple hosts based on both the ``runtime_key`` and + // :ref:`x-request-id`. + // * Useful for tracking related requests across a distributed system. 
+ // + // - When set to ``true`` or :ref:`x-request-id` is missing: + // + // * Sampling decisions are made randomly based only on the ``runtime_key``. + // * Useful in complex filter configurations (like nested + // :ref:`AndFilter`/ + // :ref:`OrFilter` blocks) where independent probability + // calculations are desired. + // * Can be used to implement logging kill switches with predictable probability distributions. + // bool use_independent_randomness = 3; } @@ -257,6 +263,7 @@ message ResponseFlagFilter { in: "DF" in: "DO" in: "DR" + in: "UDO" } } }]; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto b/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto index 94868f13432..7b862c1021a 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto @@ -16,6 +16,7 @@ import "envoy/config/metrics/v3/stats.proto"; import "envoy/config/overload/v3/overload.proto"; import "envoy/config/trace/v3/http_tracer.proto"; import "envoy/extensions/transport_sockets/tls/v3/secret.proto"; +import "envoy/type/matcher/v3/string.proto"; import "envoy/type/v3/percent.proto"; import "google/protobuf/duration.proto"; @@ -41,7 +42,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // ` for more detail. // Bootstrap :ref:`configuration overview `. -// [#next-free-field: 42] +// [#next-free-field: 43] message Bootstrap { option (udpa.annotations.versioning).previous_message_type = "envoy.config.bootstrap.v2.Bootstrap"; @@ -57,9 +58,7 @@ message Bootstrap { // If a network based configuration source is specified for :ref:`cds_config // `, it's necessary // to have some initial cluster definitions available to allow Envoy to know - // how to speak to the management server. These cluster definitions may not - // use :ref:`EDS ` (i.e. they should be static - // IP or DNS-based). 
+ // how to speak to the management server. repeated cluster.v3.Cluster clusters = 2; // These static secrets can be used by :ref:`SdsSecretConfig @@ -78,7 +77,7 @@ message Bootstrap { // :ref:`LDS ` configuration source. core.v3.ConfigSource lds_config = 1; - // xdstp:// resource locator for listener collection. + // ``xdstp://`` resource locator for listener collection. // [#not-implemented-hide:] string lds_resources_locator = 5; @@ -87,7 +86,7 @@ message Bootstrap { // configuration source. core.v3.ConfigSource cds_config = 2; - // xdstp:// resource locator for cluster collection. + // ``xdstp://`` resource locator for cluster collection. // [#not-implemented-hide:] string cds_resources_locator = 6; @@ -128,17 +127,19 @@ message Bootstrap { // When the flag is enabled, Envoy will lazily initialize a subset of the stats (see below). // This will save memory and CPU cycles when creating the objects that own these stats, if those // stats are never referenced throughout the lifetime of the process. However, it will incur additional - // memory overhead for these objects, and a small increase of CPU usage when a at least one of the stats + // memory overhead for these objects, and a small increase of CPU usage when at least one of the stats // is updated for the first time. + // // Groups of stats that will be lazily initialized: + // // - Cluster traffic stats: a subgroup of the :ref:`cluster statistics ` - // that are used when requests are routed to the cluster. + // that are used when requests are routed to the cluster. bool enable_deferred_creation_stats = 1; } message GrpcAsyncClientManagerConfig { // Optional field to set the expiration time for the cached gRPC client object. - // The minimal value is 5s and the default is 50s. + // The minimal value is ``5s`` and the default is ``50s``. 
google.protobuf.Duration max_cached_entry_idle_duration = 1 [(validate.rules).duration = {gte {seconds: 5}}]; } @@ -153,25 +154,25 @@ message Bootstrap { // A list of :ref:`Node ` field names // that will be included in the context parameters of the effective - // xdstp:// URL that is sent in a discovery request when resource + // ``xdstp://`` URL that is sent in a discovery request when resource // locators are used for LDS/CDS. Any non-string field will have its JSON // encoding set as the context parameter value, with the exception of // metadata, which will be flattened (see example below). The supported field // names are: - // - "cluster" - // - "id" - // - "locality.region" - // - "locality.sub_zone" - // - "locality.zone" - // - "metadata" - // - "user_agent_build_version.metadata" - // - "user_agent_build_version.version" - // - "user_agent_name" - // - "user_agent_version" + // - ``cluster`` + // - ``id`` + // - ``locality.region`` + // - ``locality.sub_zone`` + // - ``locality.zone`` + // - ``metadata`` + // - ``user_agent_build_version.metadata`` + // - ``user_agent_build_version.version`` + // - ``user_agent_name`` + // - ``user_agent_version`` // // The node context parameters act as a base layer dictionary for the context // parameters (i.e. more specific resource specific context parameters will - // override). Field names will be prefixed with “udpa.node.” when included in + // override). Field names will be prefixed with ````"udpa.node."```` when included in // context parameters. // // For example, if node_context_params is ``["user_agent_name", "metadata"]``, @@ -213,10 +214,10 @@ message Bootstrap { // Optional duration between flushes to configured stats sinks. For // performance reasons Envoy latches counters and only flushes counters and - // gauges at a periodic interval. If not specified the default is 5000ms (5 - // seconds). Only one of ``stats_flush_interval`` or ``stats_flush_on_admin`` + // gauges at a periodic interval. 
If not specified the default is ``5000ms`` (``5`` seconds). + // Only one of ``stats_flush_interval`` or ``stats_flush_on_admin`` // can be set. - // Duration must be at least 1ms and at most 5 min. + // Duration must be at least ``1ms`` and at most ``5 min``. google.protobuf.Duration stats_flush_interval = 7 [ (validate.rules).duration = { lt {seconds: 300} @@ -232,6 +233,14 @@ message Bootstrap { bool stats_flush_on_admin = 29 [(validate.rules).bool = {const: true}]; } + oneof stats_eviction { + // Optional duration to perform metric eviction. At every interval, during the stats flush + // the unused metrics are removed from the worker caches and the used metrics + // are marked as unused. Must be a multiple of the ``stats_flush_interval``. + google.protobuf.Duration stats_eviction_interval = 42 + [(validate.rules).duration = {gte {nanos: 1000000}}]; + } + // Optional watchdog configuration. // This is for a single watchdog configuration for the entire system. // Deprecated in favor of ``watchdogs`` which has finer granularity. @@ -265,23 +274,28 @@ message Bootstrap { (udpa.annotations.security).configure_for_untrusted_upstream = true ]; - // Enable :ref:`stats for event dispatcher `, defaults to false. - // Note that this records a value for each iteration of the event loop on every thread. This - // should normally be minimal overhead, but when using - // :ref:`statsd `, it will send each observed value - // over the wire individually because the statsd protocol doesn't have any way to represent a - // histogram summary. Be aware that this can be a very large volume of data. + // Enable :ref:`stats for event dispatcher `. Defaults to ``false``. + // + // .. note:: + // + // This records a value for each iteration of the event loop on every thread. 
This + // should normally be minimal overhead, but when using + // :ref:`statsd `, it will send each observed value + // over the wire individually because the statsd protocol doesn't have any way to represent a + // histogram summary. Be aware that this can be a very large volume of data. bool enable_dispatcher_stats = 16; - // Optional string which will be used in lieu of x-envoy in prefixing headers. + // Optional string which will be used in lieu of ``x-envoy`` in prefixing headers. // - // For example, if this string is present and set to X-Foo, then x-envoy-retry-on will be - // transformed into x-foo-retry-on etc. + // For example, if this string is present and set to ``X-Foo``, then ``x-envoy-retry-on`` will be + // transformed into ``x-foo-retry-on`` etc. // - // Note this applies to the headers Envoy will generate, the headers Envoy will sanitize, and the - // headers Envoy will trust for core code and core extensions only. Be VERY careful making - // changes to this string, especially in multi-layer Envoy deployments or deployments using - // extensions which are not upstream. + // .. note:: + // + // This applies to the headers Envoy will generate, the headers Envoy will sanitize, and the + // headers Envoy will trust for core code and core extensions only. Be VERY careful making + // changes to this string, especially in multi-layer Envoy deployments or deployments using + // extensions which are not upstream. string header_prefix = 18; // Optional proxy version which will be used to set the value of :ref:`server.version statistic @@ -289,8 +303,8 @@ message Bootstrap { // :ref:`stats sinks `. google.protobuf.UInt64Value stats_server_version_override = 19; - // Always use TCP queries instead of UDP queries for DNS lookups. - // This may be overridden on a per-cluster basis in cds_config, + // Always use ``TCP`` queries instead of ``UDP`` queries for DNS lookups. 
+ // This may be overridden on a per-cluster basis in ``cds_config``, // when :ref:`dns_resolvers ` and // :ref:`use_tcp_for_dns_lookups ` are // specified. @@ -299,8 +313,8 @@ message Bootstrap { bool use_tcp_for_dns_lookups = 20 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - // DNS resolution configuration which includes the underlying dns resolver addresses and options. - // This may be overridden on a per-cluster basis in cds_config, when + // DNS resolution configuration which includes the underlying DNS resolver addresses and options. + // This may be overridden on a per-cluster basis in ``cds_config``, when // :ref:`dns_resolution_config ` // is specified. // This field is deprecated in favor of @@ -308,14 +322,15 @@ message Bootstrap { core.v3.DnsResolutionConfig dns_resolution_config = 30 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - // DNS resolver type configuration extension. This extension can be used to configure c-ares, apple, + // DNS resolver type configuration extension. This extension can be used to configure ``c-ares``, ``apple``, // or any other DNS resolver types and the related parameters. // For example, an object of // :ref:`CaresDnsResolverConfig ` // can be packed into this ``typed_dns_resolver_config``. This configuration replaces the // :ref:`dns_resolution_config ` // configuration. - // During the transition period when both ``dns_resolution_config`` and ``typed_dns_resolver_config`` exists, + // + // During the transition period when both ``dns_resolution_config`` and ``typed_dns_resolver_config`` exist, // when ``typed_dns_resolver_config`` is in place, Envoy will use it and ignore ``dns_resolution_config``. // When ``typed_dns_resolver_config`` is missing, the default behavior is in place. 
// [#extension-category: envoy.network.dns_resolver] @@ -331,9 +346,10 @@ message Bootstrap { repeated FatalAction fatal_actions = 28; // Configuration sources that will participate in - // xdstp:// URL authority resolution. The algorithm is as + // ``xdstp://`` URL authority resolution. The algorithm is as // follows: - // 1. The authority field is taken from the xdstp:// URL, call + // + // 1. The authority field is taken from the ``xdstp://`` URL, call // this ``resource_authority``. // 2. ``resource_authority`` is compared against the authorities in any peer // ``ConfigSource``. The peer ``ConfigSource`` is the configuration source @@ -349,7 +365,7 @@ message Bootstrap { // [#not-implemented-hide:] repeated core.v3.ConfigSource config_sources = 22; - // Default configuration source for xdstp:// URLs if all + // Default configuration source for ``xdstp://`` URLs if all // other resolution fails. // [#not-implemented-hide:] core.v3.ConfigSource default_config_source = 23; @@ -369,28 +385,30 @@ message Bootstrap { // allows users to customize the inline headers on-demand at Envoy startup without modifying // Envoy's source code. // - // Note that the 'set-cookie' header cannot be registered as inline header. + // .. note:: + // + // The ``set-cookie`` header cannot be registered as inline header. repeated CustomInlineHeader inline_headers = 32; - // Optional path to a file with performance tracing data created by "Perfetto" SDK in binary - // ProtoBuf format. The default value is "envoy.pftrace". + // Optional path to a file with performance tracing data created by ``Perfetto`` SDK in binary + // ProtoBuf format. The default value is ``envoy.pftrace``. string perf_tracing_file_path = 33; // Optional overriding of default regex engine. - // If the value is not specified, Google RE2 will be used by default. + // If the value is not specified, ``Google RE2`` will be used by default. 
// [#extension-category: envoy.regex_engines] core.v3.TypedExtensionConfig default_regex_engine = 34; // Optional XdsResourcesDelegate configuration, which allows plugging custom logic into both // fetch and load events during xDS processing. - // If a value is not specified, no XdsResourcesDelegate will be used. + // If a value is not specified, no ``XdsResourcesDelegate`` will be used. // TODO(abeyad): Add public-facing documentation. // [#not-implemented-hide:] core.v3.TypedExtensionConfig xds_delegate_extension = 35; // Optional XdsConfigTracker configuration, which allows tracking xDS responses in external components, // e.g., external tracer or monitor. It provides the process point when receive, ingest, or fail to - // process xDS resources and messages. If a value is not specified, no XdsConfigTracker will be used. + // process xDS resources and messages. If a value is not specified, no ``XdsConfigTracker`` will be used. // // .. note:: // @@ -402,14 +420,14 @@ message Bootstrap { // [#not-implemented-hide:] // This controls the type of listener manager configured for Envoy. Currently - // Envoy only supports ListenerManager for this field and Envoy Mobile - // supports ApiListenerManager. + // Envoy only supports ``ListenerManager`` for this field and Envoy Mobile + // supports ``ApiListenerManager``. core.v3.TypedExtensionConfig listener_manager = 37; // Optional application log configuration. ApplicationLogConfig application_log_config = 38; - // Optional gRPC async manager config. + // Optional gRPC async client manager config. GrpcAsyncClientManagerConfig grpc_async_client_manager_config = 40; // Optional configuration for memory allocation manager. @@ -419,7 +437,7 @@ message Bootstrap { // Administration interface :ref:`operations documentation // `. 
-// [#next-free-field: 7] +// [#next-free-field: 8] message Admin { option (udpa.annotations.versioning).previous_message_type = "envoy.config.bootstrap.v2.Admin"; @@ -428,14 +446,14 @@ message Admin { repeated accesslog.v3.AccessLog access_log = 5; // The path to write the access log for the administration server. If no - // access log is desired specify ‘/dev/null’. This is only required if + // access log is desired specify ``/dev/null``. This is only required if // :ref:`address ` is set. // Deprecated in favor of ``access_log`` which offers more options. string access_log_path = 1 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - // The cpu profiler output path for the administration server. If no profile - // path is specified, the default is ‘/var/log/envoy/envoy.prof’. + // The CPU profiler output path for the administration server. If no profile + // path is specified, the default is ``/var/log/envoy/envoy.prof``. string profile_path = 2; // The TCP address that the administration server will listen on. @@ -449,6 +467,21 @@ message Admin { // Indicates whether :ref:`global_downstream_max_connections ` // should apply to the admin interface or not. bool ignore_global_conn_limit = 6; + + // List of admin paths that are accessible. If not specified, all admin endpoints are accessible. + // + // When specified, only paths in this list will be accessible, all others will return ``HTTP 403 Forbidden``. + // + // Example: + // + // .. code-block:: yaml + // + // allow_paths: + // - exact: /stats + // - exact: /ready + // - prefix: /healthcheck + // + repeated type.matcher.v3.StringMatcher allow_paths = 7; } // Cluster manager :ref:`architecture overview `. @@ -485,7 +518,7 @@ message ClusterManager { OutlierDetection outlier_detection = 2; // Optional configuration used to bind newly established upstream connections. - // This may be overridden on a per-cluster basis by upstream_bind_config in the cds_config. 
+ // This may be overridden on a per-cluster basis by ``upstream_bind_config`` in the ``cds_config``. core.v3.BindConfig upstream_bind_config = 3; // A management server endpoint to stream load stats to via @@ -496,7 +529,7 @@ message ClusterManager { // Whether the ClusterManager will create clusters on the worker threads // inline during requests. This will save memory and CPU cycles in cases where - // there are lots of inactive clusters and > 1 worker thread. + // there are lots of inactive clusters and ``> 1`` worker thread. bool enable_deferred_cluster_creation = 5; } @@ -519,12 +552,12 @@ message Watchdog { option (udpa.annotations.versioning).previous_message_type = "envoy.config.bootstrap.v2.Watchdog"; message WatchdogAction { - // The events are fired in this order: KILL, MULTIKILL, MEGAMISS, MISS. + // The events are fired in this order: ``KILL``, ``MULTIKILL``, ``MEGAMISS``, ``MISS``. // Within an event type, actions execute in the order they are configured. - // For KILL/MULTIKILL there is a default PANIC that will run after the + // For ``KILL``/``MULTIKILL`` there is a default ``PANIC`` that will run after the // registered actions and kills the process if it wasn't already killed. // It might be useful to specify several debug actions, and possibly an - // alternate FATAL action. + // alternate ``FATAL`` action. enum WatchdogEvent { UNKNOWN = 0; KILL = 1; @@ -539,46 +572,48 @@ message Watchdog { WatchdogEvent event = 2 [(validate.rules).enum = {defined_only: true}]; } - // Register actions that will fire on given WatchDog events. - // See ``WatchDogAction`` for priority of events. + // Register actions that will fire on given Watchdog events. + // See ``WatchdogAction`` for priority of events. repeated WatchdogAction actions = 7; // The duration after which Envoy counts a nonresponsive thread in the - // ``watchdog_miss`` statistic. If not specified the default is 200ms. + // ``watchdog_miss`` statistic. If not specified the default is ``200ms``. 
google.protobuf.Duration miss_timeout = 1; // The duration after which Envoy counts a nonresponsive thread in the - // ``watchdog_mega_miss`` statistic. If not specified the default is - // 1000ms. + // ``watchdog_mega_miss`` statistic. If not specified the default is ``1000ms``. google.protobuf.Duration megamiss_timeout = 2; // If a watched thread has been nonresponsive for this duration, assume a - // programming error and kill the entire Envoy process. Set to 0 to disable - // kill behavior. If not specified the default is 0 (disabled). + // programming error and kill the entire Envoy process. Set to ``0`` to disable + // kill behavior. If not specified the default is ``0`` (disabled). google.protobuf.Duration kill_timeout = 3; // Defines the maximum jitter used to adjust the ``kill_timeout`` if ``kill_timeout`` is // enabled. Enabling this feature would help to reduce risk of synchronized - // watchdog kill events across proxies due to external triggers. Set to 0 to - // disable. If not specified the default is 0 (disabled). + // watchdog kill events across proxies due to external triggers. Set to ``0`` to + // disable. If not specified the default is ``0`` (disabled). google.protobuf.Duration max_kill_timeout_jitter = 6 [(validate.rules).duration = {gte {}}]; - // If ``max(2, ceil(registered_threads * Fraction(*multikill_threshold*)))`` + // If ``max(2, ceil(registered_threads * Fraction(multikill_threshold)))`` // threads have been nonresponsive for at least this duration kill the entire - // Envoy process. Set to 0 to disable this behavior. If not specified the - // default is 0 (disabled). + // Envoy process. Set to ``0`` to disable this behavior. If not specified the + // default is ``0`` (disabled). google.protobuf.Duration multikill_timeout = 4; // Sets the threshold for ``multikill_timeout`` in terms of the percentage of // nonresponsive threads required for the ``multikill_timeout``. - // If not specified the default is 0. 
+ // If not specified the default is ``0``. type.v3.Percent multikill_threshold = 5; } // Fatal actions to run while crashing. Actions can be safe (meaning they are // async-signal safe) or unsafe. We run all safe actions before we run unsafe actions. -// If using an unsafe action that could get stuck or deadlock, it important to -// have an out of band system to terminate the process. +// +// .. note:: +// +// If using an unsafe action that could get stuck or deadlock, it is important to +// have an out of band system to terminate the process. // // The interface for the extension is ``Envoy::Server::Configuration::FatalAction``. // ``FatalAction`` extensions live in the ``envoy.extensions.fatal_actions`` API @@ -661,7 +696,7 @@ message RuntimeLayer { option (udpa.annotations.versioning).previous_message_type = "envoy.config.bootstrap.v2.RuntimeLayer.RtdsLayer"; - // Resource to subscribe to at ``rtds_config`` for the RTDS layer. + // Resource to subscribe to at the ``rtds_config`` for the RTDS layer. string name = 1; // RTDS configuration source. @@ -702,11 +737,11 @@ message LayeredRuntime { // Used to specify the header that needs to be registered as an inline header. // // If request or response contain multiple headers with the same name and the header -// name is registered as an inline header. Then multiple headers will be folded +// name is registered as an inline header, then multiple headers will be folded // into one, and multiple header values will be concatenated by a suitable delimiter. // The delimiter is generally a comma. // -// For example, if 'foo' is registered as an inline header, and the headers contains +// For example, if ``foo`` is registered as an inline header, and the headers contain // the following two headers: // // .. code-block:: text @@ -746,6 +781,6 @@ message MemoryAllocatorManager { // Interval in milliseconds for memory releasing. 
If specified, during every // interval Envoy will try to release ``bytes_to_release`` of free memory back to operating system for reuse. - // Defaults to 1000 milliseconds. + // Defaults to ``1000`` milliseconds. google.protobuf.Duration memory_release_interval = 2; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto index 0074e63dff6..192409096af 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto @@ -22,6 +22,7 @@ import "google/protobuf/struct.proto"; import "google/protobuf/wrappers.proto"; import "xds/core/v3/collection_entry.proto"; +import "xds/type/matcher/v3/matcher.proto"; import "envoy/annotations/deprecation.proto"; import "udpa/annotations/migrate.proto"; @@ -45,7 +46,7 @@ message ClusterCollection { } // Configuration for a single upstream cluster. -// [#next-free-field: 57] +// [#next-free-field: 60] message Cluster { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Cluster"; @@ -652,9 +653,10 @@ message Cluster { // If this is not set, we default to a merge window of 1000ms. To disable it, set the merge // window to 0. // - // Note: merging does not apply to cluster membership changes (e.g.: adds/removes); this is - // because merging those updates isn't currently safe. See - // https://github.com/envoyproxy/envoy/pull/3941. + // .. note:: + // Merging does not apply to cluster membership changes (e.g.: adds/removes); this is + // because merging those updates isn't currently safe. See + // https://github.com/envoyproxy/envoy/pull/3941. 
google.protobuf.Duration update_merge_window = 4; // If set to true, Envoy will :ref:`exclude ` new hosts @@ -746,6 +748,9 @@ message Cluster { // If both this and preconnect_ratio are set, Envoy will make sure both predicted needs are met, // basically preconnecting max(predictive-preconnect, per-upstream-preconnect), for each // upstream. + // + // This is limited somewhat arbitrarily to 3 because preconnecting too aggressively can + // harm latency more than the preconnecting helps. google.protobuf.DoubleValue predictive_preconnect_ratio = 2 [(validate.rules).double = {lte: 3.0 gte: 1.0}]; } @@ -754,13 +759,13 @@ message Cluster { reserved "hosts", "tls_context", "extension_protocol_options"; - // Configuration to use different transport sockets for different endpoints. The entry of + // Configuration to use different transport sockets for different endpoints. The entry of // ``envoy.transport_socket_match`` in the :ref:`LbEndpoint.Metadata // ` is used to match against the // transport sockets as they appear in the list. If a match is not found, the search continues in // :ref:`LocalityLbEndpoints.Metadata - // `. The first :ref:`match - // ` is used. For example, with + // `. The first :ref:`match + // ` is used. For example, with // the following match // // .. code-block:: yaml @@ -808,6 +813,41 @@ message Cluster { // [#comment:TODO(incfly): add a detailed architecture doc on intended usage.] repeated TransportSocketMatch transport_socket_matches = 43; + // Optional matcher that selects a transport socket from + // :ref:`transport_socket_matches `. + // + // This matcher uses the generic xDS matcher framework to select a named transport socket + // based on various inputs available at transport socket selection time. + // + // Supported matching inputs: + // + // * ``endpoint_metadata``: Extract values from the selected endpoint's metadata. + // * ``locality_metadata``: Extract values from the endpoint's locality metadata. 
+ // * ``transport_socket_filter_state``: Extract values from filter state that was explicitly shared from + // downstream to upstream via ``TransportSocketOptions``. This enables flexible + // downstream-connection-based matching, such as: + // + // - Network namespace matching. + // - Custom connection attributes. + // - Any data explicitly passed via filter state. + // + // .. note:: + // Filter state sharing follows the same pattern as tunneling in Envoy. Filters must explicitly + // share data by setting filter state with the appropriate sharing mode. The filter state is + // then accessible via the ``transport_socket_filter_state`` input during transport socket selection. + // + // If this field is set, it takes precedence over legacy metadata-based selection + // performed by :ref:`transport_socket_matches + // ` alone. + // If the matcher does not yield a match, Envoy uses the default transport socket + // configured for the cluster. + // + // When using this field, each entry in + // :ref:`transport_socket_matches ` + // must have a unique ``name``. The matcher outcome is expected to reference one of + // these names. + xds.type.matcher.v3.Matcher transport_socket_matcher = 59; + // Supplies the name of the cluster which must be unique across all clusters. // The cluster name is used when emitting // :ref:`statistics ` if :ref:`alt_stat_name @@ -816,12 +856,14 @@ message Cluster { string name = 1 [(validate.rules).string = {min_len: 1}]; // An optional alternative to the cluster name to be used for observability. This name is used - // emitting stats for the cluster and access logging the cluster name. This will appear as + // for emitting stats for the cluster and access logging the cluster name. This will appear as // additional information in configuration dumps of a cluster's current status as // :ref:`observability_name ` - // and as an additional tag "upstream_cluster.name" while tracing. 
Note: Any ``:`` in the name - // will be converted to ``_`` when emitting statistics. This should not be confused with - // :ref:`Router Filter Header `. + // and as an additional tag "upstream_cluster.name" while tracing. + // + // .. note:: + // Any ``:`` in the name will be converted to ``_`` when emitting statistics. This should not be confused with + // :ref:`Router Filter Header `. string alt_stat_name = 28 [(udpa.annotations.field_migrate).rename = "observability_name"]; oneof cluster_discovery_type { @@ -942,6 +984,7 @@ message Cluster { // "envoy.filters.network.thrift_proxy". See the extension's documentation for details on // specific options. // [#next-major-version: make this a list of typed extensions.] + // [#extension-category: envoy.upstream_options] map typed_extension_protocol_options = 36; // If the DNS refresh rate is specified and the cluster type is either @@ -953,8 +996,34 @@ message Cluster { // :ref:`STRICT_DNS` // and :ref:`LOGICAL_DNS` // this setting is ignored. - google.protobuf.Duration dns_refresh_rate = 16 - [(validate.rules).duration = {gt {nanos: 1000000}}]; + // This field is deprecated in favor of using the :ref:`cluster_type` + // extension point and configuring it with :ref:`DnsCluster`. + // If :ref:`cluster_type` is configured with + // :ref:`DnsCluster`, this field will be ignored. + google.protobuf.Duration dns_refresh_rate = 16 [ + deprecated = true, + (validate.rules).duration = {gt {nanos: 1000000}}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; + + // DNS jitter can be optionally specified if the cluster type is either + // :ref:`STRICT_DNS`, + // or :ref:`LOGICAL_DNS`. + // DNS jitter causes the cluster to refresh DNS entries later by a random amount of time to avoid a + // stampede of DNS requests. This value sets the upper bound (exclusive) for the random amount. + // There will be no jitter if this value is omitted. 
For cluster types other than + // :ref:`STRICT_DNS` + // and :ref:`LOGICAL_DNS` + // this setting is ignored. + // This field is deprecated in favor of using the :ref:`cluster_type` + // extension point and configuring it with :ref:`DnsCluster`. + // If :ref:`cluster_type` is configured with + // :ref:`DnsCluster`, this field will be ignored. + google.protobuf.Duration dns_jitter = 58 [ + deprecated = true, + (validate.rules).duration = {gte {}}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; // If the DNS failure refresh rate is specified and the cluster type is either // :ref:`STRICT_DNS`, @@ -964,16 +1033,31 @@ message Cluster { // other than :ref:`STRICT_DNS` and // :ref:`LOGICAL_DNS` this setting is // ignored. - RefreshRate dns_failure_refresh_rate = 44; + // This field is deprecated in favor of using the :ref:`cluster_type` + // extension point and configuring it with :ref:`DnsCluster`. + // If :ref:`cluster_type` is configured with + // :ref:`DnsCluster`, this field will be ignored. + RefreshRate dns_failure_refresh_rate = 44 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Optional configuration for setting cluster's DNS refresh rate. If the value is set to true, // cluster's DNS refresh rate will be set to resource record's TTL which comes from DNS // resolution. - bool respect_dns_ttl = 39; + // This field is deprecated in favor of using the :ref:`cluster_type` + // extension point and configuring it with :ref:`DnsCluster`. + // If :ref:`cluster_type` is configured with + // :ref:`DnsCluster`, this field will be ignored. + bool respect_dns_ttl = 39 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // The DNS IP address resolution policy. If this setting is not specified, the // value defaults to // :ref:`AUTO`. 
+ // For logical and strict dns cluster, this field is deprecated in favor of using the + // :ref:`cluster_type` + // extension point and configuring it with :ref:`DnsCluster`. + // If :ref:`cluster_type` is configured with + // :ref:`DnsCluster`, this field will be ignored. DnsLookupFamily dns_lookup_family = 17 [(validate.rules).enum = {defined_only: true}]; // If DNS resolvers are specified and the cluster type is either @@ -1013,6 +1097,9 @@ message Cluster { // During the transition period when both ``dns_resolution_config`` and ``typed_dns_resolver_config`` exists, // when ``typed_dns_resolver_config`` is in place, Envoy will use it and ignore ``dns_resolution_config``. // When ``typed_dns_resolver_config`` is missing, the default behavior is in place. + // Also note that this field is deprecated for logical dns and strict dns clusters and will be ignored when + // :ref:`cluster_type` is configured with + // :ref:`DnsCluster`. // [#extension-category: envoy.network.dns_resolver] core.v3.TypedExtensionConfig typed_dns_resolver_config = 55; @@ -1151,6 +1238,23 @@ message Cluster { // from the LRS stream here.] core.v3.ConfigSource lrs_server = 42; + // A list of metric names from :ref:`ORCA load reports ` to propagate to LRS. + // + // If not specified, then ORCA load reports will not be propagated to LRS. + // + // For map fields in the ORCA proto, the string will be of the form ``.``. + // For example, the string ``named_metrics.foo`` will mean to look for the key ``foo`` in the ORCA + // :ref:`named_metrics ` field. + // + // The special map key ``*`` means to report all entries in the map (e.g., ``named_metrics.*`` means to + // report all entries in the ORCA named_metrics field). Note that this should be used only with trusted + // backends. + // + // The metric names in LRS will follow the same semantics as this field. 
In other words, if this field + // contains ``named_metrics.foo``, then the LRS load report will include the data with that same string + // as the key. + repeated string lrs_report_endpoint_metrics = 57; + // If track_timeout_budgets is true, the :ref:`timeout budget histograms // ` will be published for each // request. These show what percentage of a request's per try and global timeout was used. A value @@ -1283,7 +1387,7 @@ message TrackClusterStats { // If request_response_sizes is true, then the :ref:`histograms // ` tracking header and body sizes - // of requests and responses will be published. + // of requests and responses will be published. Additionally, number of headers in the requests and responses will be tracked. bool request_response_sizes = 2; // If true, some stats will be emitted per-endpoint, similar to the stats in admin ``/clusters`` diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/common/matcher/v3/matcher.proto b/xds/third_party/envoy/src/main/proto/envoy/config/common/matcher/v3/matcher.proto new file mode 100644 index 00000000000..9b189d1aa77 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/config/common/matcher/v3/matcher.proto @@ -0,0 +1,239 @@ +syntax = "proto3"; + +package envoy.config.common.matcher.v3; + +import "envoy/config/core/v3/extension.proto"; +import "envoy/config/route/v3/route_components.proto"; +import "envoy/type/matcher/v3/string.proto"; + +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.config.common.matcher.v3"; +option java_outer_classname = "MatcherProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/common/matcher/v3;matcherv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Unified Matcher API] + +// A matcher, which may traverse a matching tree in order to result in a match action. 
+// During matching, the tree will be traversed until a match is found, or if no match +// is found the action specified by the most specific on_no_match will be evaluated. +// As an on_no_match might result in another matching tree being evaluated, this process +// might repeat several times until the final OnMatch (or no match) is decided. +// +// .. note:: +// Please use the syntactically equivalent :ref:`matching API ` +message Matcher { + // What to do if a match is successful. + message OnMatch { + oneof on_match { + option (validate.required) = true; + + // Nested matcher to evaluate. + // If the nested matcher does not match and does not specify + // on_no_match, then this matcher is considered not to have + // matched, even if a predicate at this level or above returned + // true. + Matcher matcher = 1; + + // Protocol-specific action to take. + core.v3.TypedExtensionConfig action = 2; + } + + // If true, the action will be taken but the caller will behave as if no + // match was found. This applies both to actions directly encoded in the + // action field and to actions returned from a nested matcher tree in the + // matcher field. A subsequent matcher on_no_match action will be used + // instead. + // + // This field is not supported in all contexts in which the matcher API is + // used. If this field is set in a context in which it's not supported, + // the resource will be rejected. + bool keep_matching = 3; + } + + // A linear list of field matchers. + // The field matchers are evaluated in order, and the first match + // wins. + message MatcherList { + // Predicate to determine if a match is successful. + message Predicate { + // Predicate for a single input field. + message SinglePredicate { + // Protocol-specific specification of input field to match on. 
+ // [#extension-category: envoy.matching.common_inputs] + core.v3.TypedExtensionConfig input = 1 [(validate.rules).message = {required: true}]; + + oneof matcher { + option (validate.required) = true; + + // Built-in string matcher. + type.matcher.v3.StringMatcher value_match = 2; + + // Extension for custom matching logic. + // [#extension-category: envoy.matching.input_matchers] + core.v3.TypedExtensionConfig custom_match = 3; + } + } + + // A list of two or more matchers. Used to allow using a list within a oneof. + message PredicateList { + repeated Predicate predicate = 1 [(validate.rules).repeated = {min_items: 2}]; + } + + oneof match_type { + option (validate.required) = true; + + // A single predicate to evaluate. + SinglePredicate single_predicate = 1; + + // A list of predicates to be OR-ed together. + PredicateList or_matcher = 2; + + // A list of predicates to be AND-ed together. + PredicateList and_matcher = 3; + + // The inverse of a predicate + Predicate not_matcher = 4; + } + } + + // An individual matcher. + message FieldMatcher { + // Determines if the match succeeds. + Predicate predicate = 1 [(validate.rules).message = {required: true}]; + + // What to do if the match succeeds. + OnMatch on_match = 2 [(validate.rules).message = {required: true}]; + } + + // A list of matchers. First match wins. + repeated FieldMatcher matchers = 1 [(validate.rules).repeated = {min_items: 1}]; + } + + message MatcherTree { + // A map of configured matchers. Used to allow using a map within a oneof. + message MatchMap { + map map = 1 [(validate.rules).map = {min_pairs: 1}]; + } + + // Protocol-specific specification of input field to match on. + core.v3.TypedExtensionConfig input = 1 [(validate.rules).message = {required: true}]; + + // Exact or prefix match maps in which to look up the input value. + // If the lookup succeeds, the match is considered successful, and + // the corresponding OnMatch is used. 
+ oneof tree_type { + option (validate.required) = true; + + MatchMap exact_match_map = 2; + + // Longest matching prefix wins. + MatchMap prefix_match_map = 3; + + // Extension for custom matching logic. + core.v3.TypedExtensionConfig custom_match = 4; + } + } + + oneof matcher_type { + option (validate.required) = true; + + // A linear list of matchers to evaluate. + MatcherList matcher_list = 1; + + // A match tree to evaluate. + MatcherTree matcher_tree = 2; + } + + // Optional ``OnMatch`` to use if the matcher failed. + // If specified, the ``OnMatch`` is used, and the matcher is considered + // to have matched. + // If not specified, the matcher is considered not to have matched. + OnMatch on_no_match = 3; +} + +// Match configuration. This is a recursive structure which allows complex nested match +// configurations to be built using various logical operators. +// [#next-free-field: 11] +message MatchPredicate { + // A set of match configurations used for logical operations. + message MatchSet { + // The list of rules that make up the set. + repeated MatchPredicate rules = 1 [(validate.rules).repeated = {min_items: 2}]; + } + + oneof rule { + option (validate.required) = true; + + // A set that describes a logical OR. If any member of the set matches, the match configuration + // matches. + MatchSet or_match = 1; + + // A set that describes a logical AND. If all members of the set match, the match configuration + // matches. + MatchSet and_match = 2; + + // A negation match. The match configuration will match if the negated match condition matches. + MatchPredicate not_match = 3; + + // The match configuration will always match. + bool any_match = 4 [(validate.rules).bool = {const: true}]; + + // HTTP request headers match configuration. + HttpHeadersMatch http_request_headers_match = 5; + + // HTTP request trailers match configuration. + HttpHeadersMatch http_request_trailers_match = 6; + + // HTTP response headers match configuration. 
+ HttpHeadersMatch http_response_headers_match = 7; + + // HTTP response trailers match configuration. + HttpHeadersMatch http_response_trailers_match = 8; + + // HTTP request generic body match configuration. + HttpGenericBodyMatch http_request_generic_body_match = 9; + + // HTTP response generic body match configuration. + HttpGenericBodyMatch http_response_generic_body_match = 10; + } +} + +// HTTP headers match configuration. +message HttpHeadersMatch { + // HTTP headers to match. + repeated route.v3.HeaderMatcher headers = 1; +} + +// HTTP generic body match configuration. +// List of text strings and hex strings to be located in HTTP body. +// All specified strings must be found in the HTTP body for positive match. +// The search may be limited to specified number of bytes from the body start. +// +// .. attention:: +// +// Searching for patterns in HTTP body is potentially CPU-intensive. For each specified pattern, HTTP body is scanned byte by byte to find a match. +// If multiple patterns are specified, the process is repeated for each pattern. If location of a pattern is known, ``bytes_limit`` should be specified +// to scan only part of the HTTP body. +message HttpGenericBodyMatch { + message GenericTextMatch { + oneof rule { + option (validate.required) = true; + + // Text string to be located in HTTP body. + string string_match = 1 [(validate.rules).string = {min_len: 1}]; + + // Sequence of bytes to be located in HTTP body. + bytes binary_match = 2 [(validate.rules).bytes = {min_len: 1}]; + } + } + + // Limits search to specified number of bytes - default zero (no limit - match entire captured buffer). + uint32 bytes_limit = 1; + + // List of patterns to match. 
+ repeated GenericTextMatch patterns = 2 [(validate.rules).repeated = {min_items: 1}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/common/mutation_rules/v3/mutation_rules.proto b/xds/third_party/envoy/src/main/proto/envoy/config/common/mutation_rules/v3/mutation_rules.proto new file mode 100644 index 00000000000..c015db21431 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/config/common/mutation_rules/v3/mutation_rules.proto @@ -0,0 +1,113 @@ +syntax = "proto3"; + +package envoy.config.common.mutation_rules.v3; + +import "envoy/config/core/v3/base.proto"; +import "envoy/type/matcher/v3/regex.proto"; +import "envoy/type/matcher/v3/string.proto"; + +import "google/protobuf/wrappers.proto"; + +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.config.common.mutation_rules.v3"; +option java_outer_classname = "MutationRulesProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/common/mutation_rules/v3;mutation_rulesv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Header mutation rules] + +// The HeaderMutationRules structure specifies what headers may be +// manipulated by a processing filter. This set of rules makes it +// possible to control which modifications a filter may make. +// +// By default, an external processing server may add, modify, or remove +// any header except for an "Envoy internal" header (which is typically +// denoted by an x-envoy prefix) or specific headers that may affect +// further filter processing: +// +// * ``host`` +// * ``:authority`` +// * ``:scheme`` +// * ``:method`` +// +// Every attempt to add, change, append, or remove a header will be +// tested against the rules here. Disallowed header mutations will be +// ignored unless ``disallow_is_error`` is set to true. 
+// +// Attempts to remove headers are further constrained -- regardless of the +// settings, system-defined headers (that start with ``:``) and the ``host`` +// header may never be removed. +// +// In addition, a counter will be incremented whenever a mutation is +// rejected. In the ext_proc filter, that counter is named +// ``rejected_header_mutations``. +// [#next-free-field: 8] +message HeaderMutationRules { + // By default, certain headers that could affect processing of subsequent + // filters or request routing cannot be modified. These headers are + // ``host``, ``:authority``, ``:scheme``, and ``:method``. Setting this parameter + // to true allows these headers to be modified as well. + google.protobuf.BoolValue allow_all_routing = 1; + + // If true, allow modification of envoy internal headers. By default, these + // start with ``x-envoy`` but this may be overridden in the ``Bootstrap`` + // configuration using the + // :ref:`header_prefix ` + // field. Default is false. + google.protobuf.BoolValue allow_envoy = 2; + + // If true, prevent modification of any system header, defined as a header + // that starts with a ``:`` character, regardless of any other settings. + // A processing server may still override the ``:status`` of an HTTP response + // using an ``ImmediateResponse`` message. Default is false. + google.protobuf.BoolValue disallow_system = 3; + + // If true, prevent modifications of all header values, regardless of any + // other settings. A processing server may still override the ``:status`` + // of an HTTP response using an ``ImmediateResponse`` message. Default is false. + google.protobuf.BoolValue disallow_all = 4; + + // If set, specifically allow any header that matches this regular + // expression. This overrides all other settings except for + // ``disallow_expression``. 
+ type.matcher.v3.RegexMatcher allow_expression = 5; + + // If set, specifically disallow any header that matches this regular + // expression regardless of any other settings. + type.matcher.v3.RegexMatcher disallow_expression = 6; + + // If true, and if the rules in this list cause a header mutation to be + // disallowed, then the filter using this configuration will terminate the + // request with a 500 error. In addition, regardless of the setting of this + // parameter, any attempt to set, add, or modify a disallowed header will + // cause the ``rejected_header_mutations`` counter to be incremented. + // Default is false. + google.protobuf.BoolValue disallow_is_error = 7; +} + +// The HeaderMutation structure specifies an action that may be taken on HTTP +// headers. +message HeaderMutation { + message RemoveOnMatch { + // A string matcher that will be applied to the header key. If the header key + // matches, the header will be removed. + type.matcher.v3.StringMatcher key_matcher = 1 [(validate.rules).message = {required: true}]; + } + + oneof action { + option (validate.required) = true; + + // Remove the specified header if it exists. + string remove = 1 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; + + // Append new header by the specified HeaderValueOption. + core.v3.HeaderValueOption append = 2; + + // Remove the header if the key matches the specified string matcher. 
+ RemoveOnMatch remove_on_match = 3; + } +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/address.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/address.proto index d8d47882655..17a68269e34 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/address.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/address.proto @@ -50,7 +50,7 @@ message EnvoyInternalAddress { string endpoint_id = 2; } -// [#next-free-field: 7] +// [#next-free-field: 8] message SocketAddress { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.SocketAddress"; @@ -97,6 +97,17 @@ message SocketAddress { // allow both IPv4 and IPv6 connections, with peer IPv4 addresses mapped into // IPv6 space as ``::FFFF:``. bool ipv4_compat = 6; + + // Filepath that specifies the Linux network namespace this socket will be created in (see ``man 7 + // network_namespaces``). If this field is set, Envoy will create the socket in the specified + // network namespace. + // + // .. note:: + // Setting this parameter requires Envoy to run with the ``CAP_NET_ADMIN`` capability. + // + // .. attention:: + // Network namespaces are only configurable on Linux. Otherwise, this field has no effect. + string network_namespace_filepath = 7; } message TcpKeepalive { @@ -104,16 +115,18 @@ message TcpKeepalive { // Maximum number of keepalive probes to send without response before deciding // the connection is dead. Default is to use the OS level configuration (unless - // overridden, Linux defaults to 9.) + // overridden, Linux defaults to 9.) Setting this to ``0`` disables TCP keepalive. google.protobuf.UInt32Value keepalive_probes = 1; // The number of seconds a connection needs to be idle before keep-alive probes // start being sent. Default is to use the OS level configuration (unless - // overridden, Linux defaults to 7200s (i.e., 2 hours.) + // overridden, Linux defaults to 7200s (i.e., 2 hours.) 
Setting this to ``0`` disables + // TCP keepalive. google.protobuf.UInt32Value keepalive_time = 2; // The number of seconds between keep-alive probes. Default is to use the OS - // level configuration (unless overridden, Linux defaults to 75s.) + // level configuration (unless overridden, Linux defaults to 75s.) Setting this to + // ``0`` disables TCP keepalive. google.protobuf.UInt32Value keepalive_interval = 3; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/base.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/base.proto index df91565d0a7..978f365d5f9 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/base.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/base.proto @@ -266,7 +266,7 @@ message RuntimeUInt32 { uint32 default_value = 2; // Runtime key to get value for comparison. This value is used if defined. - string runtime_key = 3 [(validate.rules).string = {min_len: 1}]; + string runtime_key = 3; } // Runtime derived percentage with a default when not specified. @@ -275,7 +275,7 @@ message RuntimePercent { type.v3.Percent default_value = 1; // Runtime key to get value for comparison. This value is used if defined. - string runtime_key = 2 [(validate.rules).string = {min_len: 1}]; + string runtime_key = 2; } // Runtime derived double with a default when not specified. @@ -286,7 +286,7 @@ message RuntimeDouble { double default_value = 1; // Runtime key to get value for comparison. This value is used if defined. - string runtime_key = 2 [(validate.rules).string = {min_len: 1}]; + string runtime_key = 2; } // Runtime derived bool with a default when not specified. @@ -300,15 +300,34 @@ message RuntimeFeatureFlag { // Runtime key to get value for comparison. This value is used if defined. The boolean value must // be represented via its // `canonical JSON encoding `_. 
- string runtime_key = 2 [(validate.rules).string = {min_len: 1}]; + string runtime_key = 2; } +// Please use :ref:`KeyValuePair ` instead. +// [#not-implemented-hide:] message KeyValue { + // The key of the key/value pair. + string key = 1 [ + deprecated = true, + (validate.rules).string = {min_len: 1 max_bytes: 16384}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; + + // The value of the key/value pair. + // + // The ``bytes`` type is used. This means if JSON or YAML is used to represent the + // configuration, the value must be base64 encoded. This is unfriendly for users in most + // use scenarios of this message. + // + bytes value = 2 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; +} + +message KeyValuePair { // The key of the key/value pair. string key = 1 [(validate.rules).string = {min_len: 1 max_bytes: 16384}]; // The value of the key/value pair. - bytes value = 2; + google.protobuf.Value value = 2; } // Key/value pair plus option to control append behavior. This is used to specify @@ -339,8 +358,18 @@ message KeyValueAppend { OVERWRITE_IF_EXISTS = 3; } - // Key/value pair entry that this option to append or overwrite. - KeyValue entry = 1 [(validate.rules).message = {required: true}]; + // The single key/value pair record to be appended or overridden. This field must be set. + KeyValuePair record = 3; + + // Key/value pair entry that this option will append or overwrite. This field is deprecated + // and please use :ref:`record ` + // as replacement. + // [#not-implemented-hide:] + KeyValue entry = 1 [ + deprecated = true, + (validate.rules).message = {skip: true}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; + // Describes the action taken to append/overwrite the given value for an existing // key or to only add this key if it's absent. @@ -349,10 +378,12 @@ message KeyValueAppend { // Key/value pair to append or remove. 
message KeyValueMutation { - // Key/value pair to append or overwrite. Only one of ``append`` or ``remove`` can be set. + // Key/value pair to append or overwrite. Only one of ``append`` or ``remove`` can be set or + // the configuration will be rejected. KeyValueAppend append = 1; - // Key to remove. Only one of ``append`` or ``remove`` can be set. + // Key to remove. Only one of ``append`` or ``remove`` can be set or the configuration will be + // rejected. string remove = 2 [(validate.rules).string = {max_bytes: 16384}]; } @@ -453,6 +484,7 @@ message HeaderValueOption { message HeaderMap { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.HeaderMap"; + // A list of header names and their values. repeated HeaderValue headers = 1; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/cel.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/cel.proto new file mode 100644 index 00000000000..940a66d0b10 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/cel.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package envoy.config.core.v3; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.config.core.v3"; +option java_outer_classname = "CelProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: CEL Expression Configuration] + +// CEL expression evaluation configuration. +// These options control the behavior of the Common Expression Language runtime for +// individual CEL expressions. +message CelExpressionConfig { + // Enable string conversion functions for CEL expressions. When enabled, CEL expressions + // can convert values to strings using the ``string()`` function. + // + // .. attention:: + // + // This option is disabled by default to avoid unbounded memory allocation. 
+ // CEL evaluation cost is typically bounded by the expression size, but converting + // arbitrary values (e.g., large messages, lists, or maps) to strings may allocate + // memory proportional to input data size, which can be unbounded and lead to + // memory exhaustion. + bool enable_string_conversion = 1; + + // Enable string concatenation for CEL expressions. When enabled, CEL expressions + // can concatenate strings using the ``+`` operator. + // + // .. attention:: + // + // This option is disabled by default to avoid unbounded memory allocation. + // While CEL normally bounds evaluation by expression size, enabling string + // concatenation allows building outputs whose size depends on input data, + // potentially causing large intermediate allocations and memory exhaustion. + bool enable_string_concat = 2; + + // Enable string manipulation functions for CEL expressions. When enabled, CEL + // expressions can use additional string functions: + // + // * ``replace(old, new)`` - Replaces all occurrences of ``old`` with ``new``. + // * ``split(separator)`` - Splits a string into a list of substrings. + // * ``lowerAscii()`` - Converts ASCII characters to lowercase. + // * ``upperAscii()`` - Converts ASCII characters to uppercase. + // + // .. note:: + // + // Standard CEL string functions like ``contains()``, ``startsWith()``, and + // ``endsWith()`` are always available regardless of this setting. + // + // .. attention:: + // + // This option is disabled by default to avoid unbounded memory allocation. + // Although CEL generally bounds evaluation by expression size, functions such as + // ``replace``, ``split``, ``lowerAscii()``, and ``upperAscii()`` can allocate memory + // proportional to input data size. Under adversarial inputs this can lead to + // unbounded allocations and memory exhaustion. 
+ bool enable_string_functions = 3; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/config_source.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/config_source.proto index f0effd99e45..430562aa5bd 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/config_source.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/config_source.proto @@ -276,7 +276,8 @@ message ExtensionConfigSource { // to be supplied. bool apply_default_config_without_warming = 3; - // A set of permitted extension type URLs. Extension configuration updates are rejected - // if they do not match any type URL in the set. + // A set of permitted extension type URLs for the type encoded inside of the + // :ref:`TypedExtensionConfig `. Extension + // configuration updates are rejected if they do not match any type URL in the set. repeated string type_urls = 4 [(validate.rules).repeated = {min_items: 1}]; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/grpc_service.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/grpc_service.proto index 5fd7921a806..9c44006b2a9 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/grpc_service.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/grpc_service.proto @@ -45,10 +45,20 @@ message GrpcService { [(validate.rules).string = {min_len: 0 max_bytes: 16384 well_known_regex: HTTP_HEADER_VALUE strict: false}]; - // Indicates the retry policy for re-establishing the gRPC stream - // This field is optional. If max interval is not provided, it will be set to ten times the provided base interval. - // Currently only supported for xDS gRPC streams. - // If not set, xDS gRPC streams default base interval:500ms, maximum interval:30s will be applied. + // Specifies the retry backoff policy for re-establishing long-lived xDS gRPC streams. + // + // This field is optional. 
If ``retry_back_off.max_interval`` is not provided, it will be set to + // ten times the configured ``retry_back_off.base_interval``. + // + // .. note:: + // + // This field is only honored for management-plane xDS gRPC streams created from + // :ref:`ApiConfigSource ` that use + // ``envoy_grpc``. Data-plane gRPC clients (for example external authorization or external + // processing filters) must use :ref:`GrpcService.retry_policy + // ` instead. + // + // If not set, xDS gRPC streams default to a base interval of 500ms and a maximum interval of 30s. RetryPolicy retry_policy = 3; // Maximum gRPC message size that is allowed to be received. @@ -64,7 +74,7 @@ message GrpcService { bool skip_envoy_headers = 5; } - // [#next-free-field: 9] + // [#next-free-field: 11] message GoogleGrpc { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.GrpcService.GoogleGrpc"; @@ -249,16 +259,31 @@ message GrpcService { } // The target URI when using the `Google C++ gRPC client - // `_. SSL credentials will be supplied in - // :ref:`channel_credentials `. + // `_. string target_uri = 1 [(validate.rules).string = {min_len: 1}]; + // The channel credentials to use. See `channel credentials + // `_. + // Ignored if ``channel_credentials_plugin`` is set. ChannelCredentials channel_credentials = 2; - // A set of call credentials that can be composed with `channel credentials + // A list of channel credentials plugins. + // The data plane will iterate over the list in order and stop at the first credential type + // that it supports. This provides a mechanism for starting to use new credential types that + // are not yet supported by all data planes. + // [#not-implemented-hide:] + repeated google.protobuf.Any channel_credentials_plugin = 9; + + // The call credentials to use. See `channel credentials // `_. + // Ignored if ``call_credentials_plugin`` is set. repeated CallCredentials call_credentials = 3; + // A list of call credentials plugins. 
All supported plugins will be used. + // Unsupported plugin types will be ignored. + // [#not-implemented-hide:] + repeated google.protobuf.Any call_credentials_plugin = 10; + // The human readable prefix to use when emitting statistics for the gRPC // service. // @@ -314,7 +339,17 @@ message GrpcService { // `. repeated HeaderValue initial_metadata = 5; - // Optional default retry policy for streams toward the service. - // If an async stream doesn't have retry policy configured in its stream options, this retry policy is used. + // Optional default retry policy for RPCs or streams initiated toward this gRPC service. + // + // If an async stream does not have a retry policy configured in its per-stream options, this + // policy is used as the default. + // + // .. note:: + // + // This field is only applied by Envoy gRPC (``envoy_grpc``) clients. Google gRPC + // (``google_grpc``) clients currently ignore this field. + // + // If not specified, no default retry policy is applied at the client level and retries only occur + // when explicitly configured in per-stream options. RetryPolicy retry_policy = 6; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/health_check.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/health_check.proto index 821f042bbe6..a4ed6e91818 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/health_check.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/health_check.proto @@ -102,7 +102,8 @@ message HealthCheck { // ``/healthcheck``. string path = 2 [(validate.rules).string = {min_len: 1 well_known_regex: HTTP_HEADER_VALUE}]; - // [#not-implemented-hide:] HTTP specific payload. + // HTTP specific payload to be sent as the request body during health checking. + // If specified, the method should support a request body (POST, PUT, PATCH, etc.). 
Payload send = 3; // Specifies a list of HTTP expected responses to match in the first ``response_buffer_size`` bytes of the response body. @@ -161,7 +162,8 @@ message HealthCheck { type.matcher.v3.StringMatcher service_name_matcher = 11; // HTTP Method that will be used for health checking, default is "GET". - // GET, HEAD, POST, PUT, DELETE, OPTIONS, TRACE, PATCH methods are supported, but making request body is not supported. + // GET, HEAD, POST, PUT, DELETE, OPTIONS, TRACE, PATCH methods are supported. + // Request body payloads are supported for POST, PUT, PATCH, and OPTIONS methods only. // CONNECT method is disallowed because it is not appropriate for health check request. // If a non-200 response is expected by the method, it needs to be set in :ref:`expected_statuses `. RequestMethod method = 13 [(validate.rules).enum = {defined_only: true not_in: 6}]; @@ -375,13 +377,13 @@ message HealthCheck { // The default value for "healthy edge interval" is the same as the default interval. google.protobuf.Duration healthy_edge_interval = 16 [(validate.rules).duration = {gt {}}]; - // .. attention:: - // This field is deprecated in favor of the extension - // :ref:`event_logger ` and - // :ref:`event_log_path ` - // in the file sink extension. - // // Specifies the path to the :ref:`health check event log `. + // + // .. attention:: + // This field is deprecated in favor of the extension + // :ref:`event_logger ` and + // :ref:`event_log_path ` + // in the file sink extension. 
string event_log_path = 17 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto index e2c5863d784..63e189e689e 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.config.core.v3; import "envoy/config/core/v3/extension.proto"; +import "envoy/type/matcher/v3/string.proto"; import "envoy/type/v3/percent.proto"; import "google/protobuf/duration.proto"; @@ -30,44 +31,80 @@ message TcpProtocolOptions { } // Config for keepalive probes in a QUIC connection. -// Note that QUIC keep-alive probing packets work differently from HTTP/2 keep-alive PINGs in a sense that the probing packet -// itself doesn't timeout waiting for a probing response. Quic has a shorter idle timeout than TCP, so it doesn't rely on such probing to discover dead connections. If the peer fails to respond, the connection will idle timeout eventually. Thus, they are configured differently from :ref:`connection_keepalive `. +// +// .. note:: +// +// QUIC keep-alive probing packets work differently from HTTP/2 keep-alive PINGs in a sense that the probing packet +// itself doesn't timeout waiting for a probing response. QUIC has a shorter idle timeout than TCP, so it doesn't rely on such probing to discover dead connections. If the peer fails to respond, the connection will idle timeout eventually. Thus, they are configured differently from :ref:`connection_keepalive `. message QuicKeepAliveSettings { - // The max interval for a connection to send keep-alive probing packets (with PING or PATH_RESPONSE). 
The value should be smaller than :ref:`connection idle_timeout ` to prevent idle timeout while not less than 1s to avoid throttling the connection or flooding the peer with probes. + // The max interval for a connection to send keep-alive probing packets (with ``PING`` or ``PATH_RESPONSE``). The value should be smaller than :ref:`connection idle_timeout ` to prevent idle timeout while not less than ``1s`` to avoid throttling the connection or flooding the peer with probes. // // If :ref:`initial_interval ` is absent or zero, a client connection will use this value to start probing. // // If zero, disable keepalive probing. // If absent, use the QUICHE default interval to probe. - google.protobuf.Duration max_interval = 1 [(validate.rules).duration = { - lte {} - gte {seconds: 1} - }]; + google.protobuf.Duration max_interval = 1; // The interval to send the first few keep-alive probing packets to prevent connection from hitting the idle timeout. Subsequent probes will be sent, each one with an interval exponentially longer than previous one, till it reaches :ref:`max_interval `. And the probes afterwards will always use :ref:`max_interval `. // // The value should be smaller than :ref:`connection idle_timeout ` to prevent idle timeout and smaller than max_interval to take effect. // - // If absent or zero, disable keepalive probing for a server connection. For a client connection, if :ref:`max_interval ` is also zero, do not keepalive, otherwise use max_interval or QUICHE default to probe all the time. + // If absent, disable keepalive probing for a server connection. For a client connection, if :ref:`max_interval ` is zero, do not keepalive, otherwise use max_interval or QUICHE default to probe all the time. google.protobuf.Duration initial_interval = 2 [(validate.rules).duration = { lte {} - gte {seconds: 1} + gte {nanos: 1000000} }]; } // QUIC protocol options which apply to both downstream and upstream connections. 
-// [#next-free-field: 9] +// [#next-free-field: 12] message QuicProtocolOptions { - // Maximum number of streams that the client can negotiate per connection. 100 + // Config for QUIC connection migration across network interfaces, i.e. cellular to WIFI, upon + // network change events from the platform, i.e. the current network gets + // disconnected, or upon the QUIC detecting a bad connection. After migration, the + // connection may be on a different network other than the default network + // picked by the platform. Both iOS and Android will use a default network to interact with the internet, usually prefer unmetered network (WIFI) + // over metered ones (cellular). And users can specify which network to be used as the default. A connection on non-default network is only allowed to + // serve new requests for a certain period of time before being drained, and + // meanwhile, QUIC will try to migrate to the default network if possible. + message ConnectionMigrationSettings { + // Config for options to migrate idle connections which aren't serving any requests. + message MigrateIdleConnectionSettings { + // If idle connections are allowed to be migrated, only migrate the connection + // if it hasn't been idle for longer than this idle period. Otherwise, the + // connection will be closed instead. + // Default to 30s. + google.protobuf.Duration max_idle_time_before_migration = 1 + [(validate.rules).duration = {gte {seconds: 1}}]; + } + + // Config whether and how to migrate idle connections. + // If absent, idle connections will not be migrated but be closed upon + // migration signals. + MigrateIdleConnectionSettings migrate_idle_connections = 1; + + // After migrating to a non-default network interface, the connection will + // only be allowed to stay on that network for up to this period of time before + // being drained unless it migrates to the default network or that network + // gets picked as the default by the device by then. + // Default to 128s. 
+ google.protobuf.Duration max_time_on_non_default_network = 2 + [(validate.rules).duration = {gte {seconds: 1}}]; + } + + // Maximum number of streams that the client can negotiate per connection. ``100`` // if not specified. google.protobuf.UInt32Value max_concurrent_streams = 1 [(validate.rules).uint32 = {gte: 1}]; // `Initial stream-level flow-control receive window // `_ size. Valid values range from - // 1 to 16777216 (2^24, maximum supported by QUICHE) and defaults to 16777216 (16 * 1024 * 1024). + // ``1`` to ``16777216`` (``2^24``, maximum supported by QUICHE) and defaults to ``16777216`` (``16 * 1024 * 1024``). + // + // .. note:: // - // NOTE: 16384 (2^14) is the minimum window size supported in Google QUIC. If configured smaller than it, we will use 16384 instead. - // QUICHE IETF Quic implementation supports 1 bytes window. We only support increasing the default window size now, so it's also the minimum. + // ``16384`` (``2^14``) is the minimum window size supported in Google QUIC. If configured smaller than it, we will use + // ``16384`` instead. QUICHE IETF QUIC implementation supports ``1`` byte window. We only support increasing the default + // window size now, so it's also the minimum. // // This field also acts as a soft limit on the number of bytes Envoy will buffer per-stream in the // QUIC stream send and receive buffers. Once the buffer reaches this pointer, watermark callbacks will fire to @@ -76,23 +113,26 @@ message QuicProtocolOptions { [(validate.rules).uint32 = {lte: 16777216 gte: 1}]; // Similar to ``initial_stream_window_size``, but for connection-level - // flow-control. Valid values rage from 1 to 25165824 (24MB, maximum supported by QUICHE) and defaults - // to 25165824 (24 * 1024 * 1024). + // flow-control. Valid values range from ``1`` to ``25165824`` (``24MB``, maximum supported by QUICHE) and defaults + // to ``25165824`` (``24 * 1024 * 1024``). + // + // .. 
note:: + // + // ``16384`` (``2^14``) is the minimum window size supported in Google QUIC. We only support increasing the default + // window size now, so it's also the minimum. // - // NOTE: 16384 (2^14) is the minimum window size supported in Google QUIC. We only support increasing the default - // window size now, so it's also the minimum. google.protobuf.UInt32Value initial_connection_window_size = 3 [(validate.rules).uint32 = {lte: 25165824 gte: 1}]; // The number of timeouts that can occur before port migration is triggered for QUIC clients. - // This defaults to 4. If set to 0, port migration will not occur on path degrading. - // Timeout here refers to QUIC internal path degrading timeout mechanism, such as PTO. + // This defaults to ``4``. If set to ``0``, port migration will not occur on path degrading. + // Timeout here refers to QUIC internal path degrading timeout mechanism, such as ``PTO``. // This has no effect on server sessions. google.protobuf.UInt32Value num_timeouts_to_trigger_port_migration = 4 [(validate.rules).uint32 = {lte: 5 gte: 0}]; - // Probes the peer at the configured interval to solicit traffic, i.e. ACK or PATH_RESPONSE, from the peer to push back connection idle timeout. - // If absent, use the default keepalive behavior of which a client connection sends PINGs every 15s, and a server connection doesn't do anything. + // Probes the peer at the configured interval to solicit traffic, i.e. ``ACK`` or ``PATH_RESPONSE``, from the peer to push back connection idle timeout. + // If absent, use the default keepalive behavior of which a client connection sends ``PING``s every ``15s``, and a server connection doesn't do anything. QuicKeepAliveSettings connection_keepalive = 5; // A comma-separated list of strings representing QUIC connection options defined in @@ -104,13 +144,35 @@ message QuicProtocolOptions { string client_connection_options = 7; // The duration that a QUIC connection stays idle before it closes itself. 
If this field is not present, QUICHE - // default 600s will be applied. + // default ``600s`` will be applied. // For internal corporate network, a long timeout is often fine. - // But for client facing network, 30s is usually a good choice. - google.protobuf.Duration idle_network_timeout = 8 [(validate.rules).duration = { - lte {seconds: 600} - gte {seconds: 1} - }]; + // But for client facing network, ``30s`` is usually a good choice. + // Do not add an upper bound here. A long idle timeout is useful for maintaining warm connections at non-front-line proxy for low QPS services. + google.protobuf.Duration idle_network_timeout = 8 + [(validate.rules).duration = {gte {seconds: 1}}]; + + // Maximum packet length for QUIC connections. It refers to the largest size of a QUIC packet that can be transmitted over the connection. + // If not specified, one of the `default values in QUICHE `_ is used. + google.protobuf.UInt64Value max_packet_length = 9; + + // A customized UDP socket and a QUIC packet writer using the socket for + // client connections. i.e. Mobile uses its own implementation to interact + // with platform socket APIs. + // If not present, the default platform-independent socket and writer will be used. + // [#extension-category: envoy.quic.client_packet_writer] + TypedExtensionConfig client_packet_writer = 10; + + // Enable QUIC `connection migration + // ` + // to a different network interface when the current network is degrading or + // has become bad. + // In order to use a different network interface other than the platform's default one, + // a customized :ref:`client_packet_writer ` needs to be configured to + // create UDP sockets on non-default networks. + // Only takes effect when runtime key ``envoy.reloadable_features.use_migration_in_quiche`` is true. + // If absent, the feature will be disabled. 
+ // [#not-implemented-hide:] + ConnectionMigrationSettings connection_migration = 11; } message UpstreamHttpProtocolOptions { @@ -122,6 +184,9 @@ message UpstreamHttpProtocolOptions { // header when :ref:`override_auto_sni_header ` // is set, as seen by the :ref:`router filter `. // Does nothing if a filter before the http router filter sets the corresponding metadata. + // + // See :ref:`SNI configuration ` for details on how this + // interacts with other validation options. bool auto_sni = 1; // Automatic validate upstream presented certificate for new upstream connections based on the @@ -129,6 +194,9 @@ message UpstreamHttpProtocolOptions { // is set, as seen by the :ref:`router filter `. // This field is intended to be set with ``auto_sni`` field. // Does nothing if a filter before the http router filter sets the corresponding metadata. + // + // See :ref:`validation configuration ` for how this interacts with + // other validation options. bool auto_san_validation = 2; // An optional alternative to the host/authority header to be used for setting the SNI value. @@ -174,9 +242,9 @@ message AlternateProtocolsCacheOptions { // not the case. string name = 1 [(validate.rules).string = {min_len: 1}]; - // The maximum number of entries that the cache will hold. If not specified defaults to 1024. + // The maximum number of entries that the cache will hold. If not specified defaults to ``1024``. // - // .. note: + // .. 
note:: // // The implementation is approximate and enforced independently on each worker thread, thus // it is possible for the maximum entries in the cache to go slightly above the configured @@ -205,7 +273,7 @@ message AlternateProtocolsCacheOptions { repeated string canonical_suffixes = 5; } -// [#next-free-field: 7] +// [#next-free-field: 8] message HttpProtocolOptions { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.HttpProtocolOptions"; @@ -219,14 +287,14 @@ message HttpProtocolOptions { // Allow headers with underscores. This is the default behavior. ALLOW = 0; - // Reject client request. HTTP/1 requests are rejected with the 400 status. HTTP/2 requests - // end with the stream reset. The "httpN.requests_rejected_with_underscores_in_headers" counter + // Reject client request. HTTP/1 requests are rejected with ``HTTP 400`` status. HTTP/2 requests + // end with the stream reset. The ``httpN.requests_rejected_with_underscores_in_headers`` counter // is incremented for each rejected request. REJECT_REQUEST = 1; // Drop the client header with name containing underscores. The header is dropped before the filter chain is // invoked and as such filters will not see dropped headers. The - // "httpN.dropped_headers_with_underscores" is incremented for each dropped header. + // ``httpN.dropped_headers_with_underscores`` is incremented for each dropped header. DROP_HEADER = 2; } @@ -236,8 +304,12 @@ message HttpProtocolOptions { // downstream connection a drain sequence will occur prior to closing the connection, see // :ref:`drain_timeout // `. - // Note that request based timeouts mean that HTTP/2 PINGs will not keep the connection alive. - // If not specified, this defaults to 1 hour. To disable idle timeouts explicitly set this to 0. + // + // .. note:: + // + // Request based timeouts mean that HTTP/2 PINGs will not keep the connection alive. + // + // If not specified, this defaults to ``1 hour``. 
To disable idle timeouts explicitly set this to ``0``. // // .. warning:: // Disabling this timeout has a highly likelihood of yielding connection leaks due to lost TCP @@ -249,37 +321,66 @@ message HttpProtocolOptions { google.protobuf.Duration idle_timeout = 1; // The maximum duration of a connection. The duration is defined as a period since a connection - // was established. If not set, there is no max duration. When max_connection_duration is reached - // and if there are no active streams, the connection will be closed. If the connection is a - // downstream connection and there are any active streams, the drain sequence will kick-in, - // and the connection will be force-closed after the drain period. See :ref:`drain_timeout + // was established. If not set, there is no max duration. When max_connection_duration is reached, + // the drain sequence will kick-in. The connection will be closed after the drain timeout period + // if there are no active streams. See :ref:`drain_timeout // `. google.protobuf.Duration max_connection_duration = 3; - // The maximum number of headers. If unconfigured, the default - // maximum number of request headers allowed is 100. Requests that exceed this limit will receive - // a 431 response for HTTP/1.x and cause a stream reset for HTTP/2. + // The maximum number of headers (request headers if configured on HttpConnectionManager, + // response headers when configured on a cluster). + // If unconfigured, the default maximum number of headers allowed is ``100``. + // The default value for requests can be overridden by setting runtime key ``envoy.reloadable_features.max_request_headers_count``. + // The default value for responses can be overridden by setting runtime key ``envoy.reloadable_features.max_response_headers_count``. + // Downstream requests that exceed this limit will receive a ``HTTP 431`` response for HTTP/1.x and cause a stream + // reset for HTTP/2. 
+ // Upstream responses that exceed this limit will result in a ``HTTP 502`` response. google.protobuf.UInt32Value max_headers_count = 2 [(validate.rules).uint32 = {gte: 1}]; + // The maximum size of response headers. + // If unconfigured, the default is ``60 KiB``, except for HTTP/1 response headers which have a default + // of ``80 KiB``. + // The default value can be overridden by setting runtime key ``envoy.reloadable_features.max_response_headers_size_kb``. + // Responses that exceed this limit will result in a ``HTTP 503`` response. + // In Envoy, this setting is only valid when configured on an upstream cluster, not on the + // :ref:`HTTP Connection Manager + // `. + // + // .. note:: + // + // Currently some protocol codecs impose limits on the maximum size of a single header. + // + // * HTTP/2 (when using ``nghttp2``) limits a single header to around ``100kb``. + // * HTTP/3 limits a single header to around ``1024kb``. + // + google.protobuf.UInt32Value max_response_headers_kb = 7 + [(validate.rules).uint32 = {lte: 8192 gt: 0}]; + // Total duration to keep alive an HTTP request/response stream. If the time limit is reached the stream will be // reset independent of any other timeouts. If not specified, this value is not set. google.protobuf.Duration max_stream_duration = 4; // Action to take when a client request with a header name containing underscore characters is received. - // If this setting is not specified, the value defaults to ALLOW. - // Note: upstream responses are not affected by this setting. - // Note: this only affects client headers. It does not affect headers added - // by Envoy filters and does not have any impact if added to cluster config. + // If this setting is not specified, the value defaults to ``ALLOW``. + // + // .. note:: + // + // Upstream responses are not affected by this setting. + // + // .. note:: + // + // This only affects client headers. 
It does not affect headers added by Envoy filters and does not have any + // impact if added to cluster config. HeadersWithUnderscoresAction headers_with_underscores_action = 5; // Optional maximum requests for both upstream and downstream connections. // If not specified, there is no limit. - // Setting this parameter to 1 will effectively disable keep alive. + // Setting this parameter to ``1`` will effectively disable keep alive. // For HTTP/2 and HTTP/3, due to concurrent stream processing, the limit is approximate. google.protobuf.UInt32Value max_requests_per_connection = 6; } -// [#next-free-field: 11] +// [#next-free-field: 12] message Http1ProtocolOptions { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.Http1ProtocolOptions"; @@ -299,9 +400,12 @@ message Http1ProtocolOptions { // Formats the header by proper casing words: the first character and any character following // a special character will be capitalized if it's an alpha character. For example, - // "content-type" becomes "Content-Type", and "foo$b#$are" becomes "Foo$B#$Are". - // Note that while this results in most headers following conventional casing, certain headers - // are not covered. For example, the "TE" header will be formatted as "Te". + // ``"content-type"`` becomes ``"Content-Type"``, and ``"foo$b#$are"`` becomes ``"Foo$B#$Are"``. + // + // .. note:: + // + // While this results in most headers following conventional casing, certain headers + // are not covered. For example, the ``"TE"`` header will be formatted as ``"Te"``. ProperCaseWords proper_case_words = 1; // Configuration for stateful formatter extensions that allow using received headers to @@ -317,7 +421,7 @@ message Http1ProtocolOptions { // ``http_proxy`` environment variable. google.protobuf.BoolValue allow_absolute_url = 1; - // Handle incoming HTTP/1.0 and HTTP 0.9 requests. + // Handle incoming HTTP/1.0 and HTTP/0.9 requests. // This is off by default, and not fully standards compliant. 
There is support for pre-HTTP/1.1 // style connect logic, dechunking, and handling lack of client host iff // ``default_host_for_http_10`` is configured. @@ -336,19 +440,20 @@ message Http1ProtocolOptions { // // .. attention:: // - // Note that this only happens when Envoy is chunk encoding which occurs when: + // This only happens when Envoy is chunk encoding which occurs when: // - The request is HTTP/1.1. - // - Is neither a HEAD only request nor a HTTP Upgrade. - // - Not a response to a HEAD request. - // - The content length header is not present. + // - Is neither a ``HEAD`` only request nor a HTTP Upgrade. + // - Not a response to a ``HEAD`` request. + // - The ``Content-Length`` header is not present. bool enable_trailers = 5; // Allows Envoy to process requests/responses with both ``Content-Length`` and ``Transfer-Encoding`` // headers set. By default such messages are rejected, but if option is enabled - Envoy will - // remove Content-Length header and process message. + // remove ``Content-Length`` header and process message. // See `RFC7230, sec. 3.3.3 `_ for details. // // .. attention:: + // // Enabling this option might lead to request smuggling vulnerability, especially if traffic // is proxied via multiple layers of proxies. // [#comment:TODO: This field is ignored when the @@ -377,7 +482,7 @@ message Http1ProtocolOptions { // envoy.reloadable_features.http1_use_balsa_parser. // See issue #21245. google.protobuf.BoolValue use_balsa_parser = 9 - [(xds.annotations.v3.field_status).work_in_progress = true]; + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // [#not-implemented-hide:] Hiding so that field can be removed. // If true, and BalsaParser is used (either `use_balsa_parser` above is true, @@ -391,6 +496,14 @@ message Http1ProtocolOptions { // ` // to reject custom methods. 
bool allow_custom_methods = 10 [(xds.annotations.v3.field_status).work_in_progress = true]; + + // Ignore HTTP/1.1 upgrade values matching any of the supplied matchers. + // + // .. note:: + // + // ``h2c`` upgrades are always removed for backwards compatibility, regardless of the + // value in this setting. + repeated type.matcher.v3.StringMatcher ignore_http_11_upgrade = 11; } message KeepaliveSettings { @@ -399,9 +512,12 @@ message KeepaliveSettings { google.protobuf.Duration interval = 1 [(validate.rules).duration = {gte {nanos: 1000000}}]; // How long to wait for a response to a keepalive PING. If a response is not received within this - // time period, the connection will be aborted. Note that in order to prevent the influence of - // Head-of-line (HOL) blocking the timeout period is extended when *any* frame is received on - // the connection, under the assumption that if a frame is received the connection is healthy. + // time period, the connection will be aborted. + // + // .. note:: + // + // In order to prevent the influence of Head-of-line (HOL) blocking the timeout period is extended when *any* frame is received on + // the connection, under the assumption that if a frame is received the connection is healthy. google.protobuf.Duration timeout = 2 [(validate.rules).duration = { required: true gte {nanos: 1000000} @@ -409,7 +525,7 @@ message KeepaliveSettings { // A random jitter amount as a percentage of interval that will be added to each interval. // A value of zero means there will be no jitter. - // The default value is 15%. + // The default value is ``15%``. 
type.v3.Percent interval_jitter = 3; // If the connection has been idle for this duration, send a HTTP/2 ping ahead @@ -423,7 +539,7 @@ message KeepaliveSettings { [(validate.rules).duration = {gte {nanos: 1000000}}]; } -// [#next-free-field: 17] +// [#next-free-field: 19] message Http2ProtocolOptions { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.Http2ProtocolOptions"; @@ -446,13 +562,13 @@ message Http2ProtocolOptions { // `Maximum table size `_ // (in octets) that the encoder is permitted to use for the dynamic HPACK table. Valid values - // range from 0 to 4294967295 (2^32 - 1) and defaults to 4096. 0 effectively disables header + // range from ``0`` to ``4294967295`` (``2^32 - 1``) and defaults to ``4096``. ``0`` effectively disables header // compression. google.protobuf.UInt32Value hpack_table_size = 1; // `Maximum concurrent streams `_ - // allowed for peer on one HTTP/2 connection. Valid values range from 1 to 2147483647 (2^31 - 1) - // and defaults to 2147483647. + // allowed for peer on one HTTP/2 connection. Valid values range from ``1`` to ``2147483647`` (``2^31 - 1``) + // and defaults to ``1024`` for safety and should be sufficient for most use cases. // // For upstream connections, this also limits how many streams Envoy will initiate concurrently // on a single connection. If the limit is reached, Envoy may queue requests or establish @@ -465,12 +581,14 @@ message Http2ProtocolOptions { [(validate.rules).uint32 = {lte: 2147483647 gte: 1}]; // `Initial stream-level flow-control window - // `_ size. Valid values range from 65535 - // (2^16 - 1, HTTP/2 default) to 2147483647 (2^31 - 1, HTTP/2 maximum) and defaults to 268435456 - // (256 * 1024 * 1024). + // `_ size. Valid values range from ``65535`` + // (``2^16 - 1``, HTTP/2 default) to ``2147483647`` (``2^31 - 1``, HTTP/2 maximum) and defaults to + // ``16MiB`` (``16 * 1024 * 1024``). // - // NOTE: 65535 is the initial window size from HTTP/2 spec. 
We only support increasing the default - // window size now, so it's also the minimum. + // .. note:: + // + // ``65535`` is the initial window size from HTTP/2 spec. We only support increasing the default window size now, + // so it's also the minimum. // // This field also acts as a soft limit on the number of bytes Envoy will buffer per-stream in the // HTTP/2 codec buffers. Once the buffer reaches this pointer, watermark callbacks will fire to @@ -479,7 +597,7 @@ message Http2ProtocolOptions { [(validate.rules).uint32 = {lte: 2147483647 gte: 65535}]; // Similar to ``initial_stream_window_size``, but for connection-level flow-control - // window. Currently, this has the same minimum/maximum/default as ``initial_stream_window_size``. + // window. The default is ``24MiB`` (``24 * 1024 * 1024``). google.protobuf.UInt32Value initial_connection_window_size = 4 [(validate.rules).uint32 = {lte: 2147483647 gte: 65535}]; @@ -497,51 +615,51 @@ message Http2ProtocolOptions { // Limit the number of pending outbound downstream frames of all types (frames that are waiting to // be written into the socket). Exceeding this limit triggers flood mitigation and connection is // terminated. The ``http2.outbound_flood`` stat tracks the number of terminated connections due - // to flood mitigation. The default limit is 10000. + // to flood mitigation. The default limit is ``10000``. google.protobuf.UInt32Value max_outbound_frames = 7 [(validate.rules).uint32 = {gte: 1}]; - // Limit the number of pending outbound downstream frames of types PING, SETTINGS and RST_STREAM, + // Limit the number of pending outbound downstream frames of types ``PING``, ``SETTINGS`` and ``RST_STREAM``, // preventing high memory utilization when receiving continuous stream of these frames. Exceeding // this limit triggers flood mitigation and connection is terminated. The // ``http2.outbound_control_flood`` stat tracks the number of terminated connections due to flood - // mitigation. 
The default limit is 1000. + // mitigation. The default limit is ``1000``. google.protobuf.UInt32Value max_outbound_control_frames = 8 [(validate.rules).uint32 = {gte: 1}]; - // Limit the number of consecutive inbound frames of types HEADERS, CONTINUATION and DATA with an + // Limit the number of consecutive inbound frames of types ``HEADERS``, ``CONTINUATION`` and ``DATA`` with an // empty payload and no end stream flag. Those frames have no legitimate use and are abusive, but - // might be a result of a broken HTTP/2 implementation. The `http2.inbound_empty_frames_flood`` + // might be a result of a broken HTTP/2 implementation. The ``http2.inbound_empty_frames_flood`` // stat tracks the number of connections terminated due to flood mitigation. - // Setting this to 0 will terminate connection upon receiving first frame with an empty payload - // and no end stream flag. The default limit is 1. + // Setting this to ``0`` will terminate connection upon receiving first frame with an empty payload + // and no end stream flag. The default limit is ``1``. google.protobuf.UInt32Value max_consecutive_inbound_frames_with_empty_payload = 9; - // Limit the number of inbound PRIORITY frames allowed per each opened stream. If the number - // of PRIORITY frames received over the lifetime of connection exceeds the value calculated + // Limit the number of inbound ``PRIORITY`` frames allowed per each opened stream. If the number + // of ``PRIORITY`` frames received over the lifetime of connection exceeds the value calculated // using this formula:: // // ``max_inbound_priority_frames_per_stream`` * (1 + ``opened_streams``) // // the connection is terminated. For downstream connections the ``opened_streams`` is incremented when // Envoy receives complete response headers from the upstream server. For upstream connection the - // ``opened_streams`` is incremented when Envoy send the HEADERS frame for a new stream. 
The + // ``opened_streams`` is incremented when Envoy sends the ``HEADERS`` frame for a new stream. The // ``http2.inbound_priority_frames_flood`` stat tracks - // the number of connections terminated due to flood mitigation. The default limit is 100. + // the number of connections terminated due to flood mitigation. The default limit is ``100``. google.protobuf.UInt32Value max_inbound_priority_frames_per_stream = 10; - // Limit the number of inbound WINDOW_UPDATE frames allowed per DATA frame sent. If the number - // of WINDOW_UPDATE frames received over the lifetime of connection exceeds the value calculated + // Limit the number of inbound ``WINDOW_UPDATE`` frames allowed per ``DATA`` frame sent. If the number + // of ``WINDOW_UPDATE`` frames received over the lifetime of connection exceeds the value calculated // using this formula:: // - // 5 + 2 * (``opened_streams`` + - // ``max_inbound_window_update_frames_per_data_frame_sent`` * ``outbound_data_frames``) + // ``5 + 2 * (opened_streams + + // max_inbound_window_update_frames_per_data_frame_sent * outbound_data_frames)`` // // the connection is terminated. For downstream connections the ``opened_streams`` is incremented when // Envoy receives complete response headers from the upstream server. For upstream connections the - // ``opened_streams`` is incremented when Envoy sends the HEADERS frame for a new stream. The + // ``opened_streams`` is incremented when Envoy sends the ``HEADERS`` frame for a new stream. The // ``http2.inbound_priority_frames_flood`` stat tracks the number of connections terminated due to - // flood mitigation. The default max_inbound_window_update_frames_per_data_frame_sent value is 10. - // Setting this to 1 should be enough to support HTTP/2 implementations with basic flow control, - // but more complex implementations that try to estimate available bandwidth require at least 2. + // flood mitigation. 
The default ``max_inbound_window_update_frames_per_data_frame_sent`` value is ``10``. + // Setting this to ``1`` should be enough to support HTTP/2 implementations with basic flow control, + // but more complex implementations that try to estimate available bandwidth require at least ``2``. google.protobuf.UInt32Value max_inbound_window_update_frames_per_data_frame_sent = 11 [(validate.rules).uint32 = {gte: 1}]; @@ -579,8 +697,10 @@ message Http2ProtocolOptions { // 2. SETTINGS_ENABLE_CONNECT_PROTOCOL (0x8) is only configurable through the named field // 'allow_connect'. // - // Note that custom parameters specified through this field can not also be set in the - // corresponding named parameters: + // .. note:: + // + // Custom parameters specified through this field can not also be set in the + // corresponding named parameters: // // .. code-block:: text // @@ -607,6 +727,15 @@ message Http2ProtocolOptions { // If unset, HTTP/2 codec is selected based on envoy.reloadable_features.http2_use_oghttp2. google.protobuf.BoolValue use_oghttp2_codec = 16 [(xds.annotations.v3.field_status).work_in_progress = true]; + + // Configure the maximum amount of metadata than can be handled per stream. Defaults to ``1 MB``. + google.protobuf.UInt64Value max_metadata_size = 17; + + // Controls whether to encode headers using huffman encoding. + // This can be useful in cases where the cpu spent encoding the headers isn't + // worth the network bandwidth saved e.g. for localhost. + // If unset, uses the data plane's default value. + google.protobuf.BoolValue enable_huffman_encoding = 18; } // [#not-implemented-hide:] @@ -618,7 +747,7 @@ message GrpcProtocolOptions { } // A message which allows using HTTP/3. -// [#next-free-field: 7] +// [#next-free-field: 9] message Http3ProtocolOptions { QuicProtocolOptions quic_protocol_options = 1; @@ -635,7 +764,10 @@ message Http3ProtocolOptions { // `_ // and settings `proposed for HTTP/3 // `_ - // Note that HTTP/3 CONNECT is not yet an RFC. 
+ // + // .. note:: + // + // HTTP/3 CONNECT is not yet an RFC. bool allow_extended_connect = 5 [(xds.annotations.v3.field_status).work_in_progress = true]; // [#not-implemented-hide:] Hiding until Envoy has full metadata support. @@ -645,19 +777,31 @@ message Http3ProtocolOptions { // docs](https://github.com/envoyproxy/envoy/blob/main/source/docs/h2_metadata.md) for more // information. bool allow_metadata = 6; + + // [#not-implemented-hide:] Hiding until Envoy has full HTTP/3 upstream support. + // Still under implementation. DO NOT USE. + // + // Disables QPACK compression related features for HTTP/3 including: + // No huffman encoding, zero dynamic table capacity and no cookie crumbling. + // This can be useful for trading off CPU vs bandwidth when an upstream HTTP/3 connection multiplexes multiple downstream connections. + bool disable_qpack = 7; + + // Disables connection level flow control for HTTP/3 streams. This is useful in situations where the streams share the same connection + // but originate from different end-clients, so that each stream can make progress independently at non-front-line proxies. + bool disable_connection_flow_control_for_streams = 8; } // A message to control transformations to the :scheme header message SchemeHeaderTransformation { oneof transformation { // Overwrite any Scheme header with the contents of this string. - // If set, takes precedence over match_upstream. + // If set, takes precedence over ``match_upstream``. string scheme_to_overwrite = 1 [(validate.rules).string = {in: "http" in: "https"}]; } // Set the Scheme header to match the upstream transport protocol. For example, should a - // request be sent to the upstream over TLS, the scheme header will be set to "https". Should the - // request be sent over plaintext, the scheme header will be set to "http". - // If scheme_to_overwrite is set, this field is not used. + // request be sent to the upstream over TLS, the scheme header will be set to ``"https"``. 
Should the + // request be sent over plaintext, the scheme header will be set to ``"http"``. + // If ``scheme_to_overwrite`` is set, this field is not used. bool match_upstream = 2; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/proxy_protocol.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/proxy_protocol.proto index 32747dd2288..2da5fe5fd4d 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/proxy_protocol.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/proxy_protocol.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package envoy.config.core.v3; +import "envoy/config/core/v3/substitution_format_string.proto"; + import "udpa/annotations/status.proto"; import "validate/validate.proto"; @@ -32,6 +34,34 @@ message ProxyProtocolPassThroughTLVs { repeated uint32 tlv_type = 2 [(validate.rules).repeated = {items {uint32 {lt: 256}}}]; } +// Represents a single Type-Length-Value (TLV) entry. +message TlvEntry { + // The type of the TLV. Must be a uint8 (0-255) as per the Proxy Protocol v2 specification. + uint32 type = 1 [(validate.rules).uint32 = {lt: 256}]; + + // The static value of the TLV. + // Only one of ``value`` or ``format_string`` may be set. + bytes value = 2; + + // Uses the :ref:`format string ` to dynamically + // populate the TLV value from stream information. This allows dynamic values + // such as metadata, filter state, or other stream properties to be included in + // the TLV. + // + // For example: + // + // .. code-block:: yaml + // + // type: 0xF0 + // format_string: + // text_format_source: + // inline_string: "%DYNAMIC_METADATA(envoy.filters.network:key)%" + // + // The formatted string will be used directly as the TLV value. + // Only one of ``value`` or ``format_string`` may be set. + SubstitutionFormatString format_string = 3; +} + message ProxyProtocolConfig { enum Version { // PROXY protocol version 1. Human readable format. 
@@ -47,4 +77,38 @@ message ProxyProtocolConfig { // This config controls which TLVs can be passed to upstream if it is Proxy Protocol // V2 header. If there is no setting for this field, no TLVs will be passed through. ProxyProtocolPassThroughTLVs pass_through_tlvs = 2; + + // This config allows additional TLVs to be included in the upstream PROXY protocol + // V2 header. Unlike ``pass_through_tlvs``, which passes TLVs from the downstream request, + // ``added_tlvs`` provides an extension mechanism for defining new TLVs that are included + // with the upstream request. These TLVs may not be present in the downstream request and + // can be defined at either the transport socket level or the host level to provide more + // granular control over the TLVs that are included in the upstream request. + // + // Host-level TLVs are specified in the ``metadata.typed_filter_metadata`` field under the + // ``envoy.transport_sockets.proxy_protocol`` namespace. + // + // .. literalinclude:: /_configs/repo/proxy_protocol.yaml + // :language: yaml + // :lines: 49-57 + // :linenos: + // :lineno-start: 49 + // :caption: :download:`proxy_protocol.yaml ` + // + // **Precedence behavior**: + // + // - When a TLV is defined at both the host level and the transport socket level, the value + // from the host level configuration takes precedence. This allows users to define default TLVs + // at the transport socket level and override them at the host level. + // - Any TLV defined in the ``pass_through_tlvs`` field will be overridden by either the host-level + // or transport socket-level TLV. + // + // If there are multiple TLVs with the same type, only the TLVs from the highest precedence level + // will be used. + repeated TlvEntry added_tlvs = 3; +} + +message PerHostConfig { + // Enables per-host configuration for Proxy Protocol. 
+ repeated TlvEntry added_tlvs = 1; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_cmsg_headers.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_cmsg_headers.proto new file mode 100644 index 00000000000..cc3e58e0996 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_cmsg_headers.proto @@ -0,0 +1,28 @@ +syntax = "proto3"; + +package envoy.config.core.v3; + +import "google/protobuf/wrappers.proto"; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.config.core.v3"; +option java_outer_classname = "SocketCmsgHeadersProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Socket CMSG headers] + +// Configuration for socket cmsg headers. +// See `:ref:CMSG `_ for further information. +message SocketCmsgHeaders { + // cmsg level. Default is unset. + google.protobuf.UInt32Value level = 1; + + // cmsg type. Default is unset. + google.protobuf.UInt32Value type = 2; + + // Expected size of cmsg value. Default is zero. + uint32 expected_size = 3; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_option.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_option.proto index 44f1ce3890a..ad73d72e490 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_option.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_option.proto @@ -36,7 +36,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // :ref:`admin's ` socket_options etc. // // It should be noted that the name or level may have different values on different platforms. 
-// [#next-free-field: 7] +// [#next-free-field: 8] message SocketOption { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.SocketOption"; @@ -51,6 +51,29 @@ message SocketOption { STATE_LISTENING = 2; } + // The `socket type `_ to apply the socket option to. + // Only one field should be set. If multiple fields are set, the precedence order will determine + // the selected one. If none of the fields is set, the socket option will be applied to all socket types. + // + // For example: + // If :ref:`stream ` is set, + // it takes precedence over :ref:`datagram `. + message SocketType { + // The stream socket type. + message Stream { + } + + // The datagram socket type. + message Datagram { + } + + // Apply the socket option to the stream socket type. + Stream stream = 1; + + // Apply the socket option to the datagram socket type. + Datagram datagram = 2; + } + // An optional name to give this socket option for debugging, etc. // Uniqueness is not required and no special meaning is assumed. string description = 1; @@ -74,6 +97,10 @@ message SocketOption { // The state in which the option will be applied. When used in BindConfig // STATE_PREBIND is currently the only valid value. SocketState state = 6 [(validate.rules).enum = {defined_only: true}]; + + // Apply the socket option to the specified `socket type `_. + // If not specified, the socket option will be applied to all socket types. 
+ SocketType type = 7; } message SocketOptionsOverride { diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/substitution_format_string.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/substitution_format_string.proto index abe8afa68ae..3edbf5f5f00 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/substitution_format_string.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/substitution_format_string.proto @@ -22,7 +22,12 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Optional configuration options to be used with json_format. message JsonFormatOptions { // The output JSON string properties will be sorted. - bool sort_properties = 1; + // + // .. note:: + // As the properties are always sorted, this option has no effect and is deprecated. + // + bool sort_properties = 1 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; } // Configuration to use multiple :ref:`command operators ` @@ -101,6 +106,12 @@ message SubstitutionFormatString { // * for ``text_format``, the output of the empty operator is changed from ``-`` to an // empty string, so that empty values are omitted entirely. // * for ``json_format`` the keys with null values are omitted in the output structure. + // + // .. note:: + // This option does not work perfectly with ``json_format`` as keys with ``null`` values + // will still be included in the output. See https://github.com/envoyproxy/envoy/issues/37941 + // for more details. + // bool omit_empty_values = 3; // Specify a ``content_type`` field. 
diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint.proto b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint.proto index 894f68310a4..a149f6095c1 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint.proto @@ -113,8 +113,9 @@ message ClusterLoadAssignment { // to determine the health of the priority level, or in other words assume each host has a weight of 1 for // this calculation. // - // Note: this is not currently implemented for - // :ref:`locality weighted load balancing `. + // .. note:: + // This is not currently implemented for + // :ref:`locality weighted load balancing `. bool weighted_priority_health = 6; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto index 6673691105e..eacc555df73 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto @@ -9,6 +9,9 @@ import "envoy/config/core/v3/health_check.proto"; import "google/protobuf/wrappers.proto"; +import "xds/core/v3/collection_entry.proto"; + +import "envoy/annotations/deprecation.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -133,14 +136,24 @@ message LbEndpoint { google.protobuf.UInt32Value load_balancing_weight = 4 [(validate.rules).uint32 = {gte: 1}]; } +// LbEndpoint list collection. Entries are `LbEndpoint` resources or references. // [#not-implemented-hide:] -// A configuration for a LEDS collection. +message LbEndpointCollection { + xds.core.v3.CollectionEntry entries = 1; +} + +// A configuration for an LEDS collection. 
message LedsClusterLocalityConfig { // Configuration for the source of LEDS updates for a Locality. core.v3.ConfigSource leds_config = 1; - // The xDS transport protocol glob collection resource name. - // The service is only supported in delta xDS (incremental) mode. + // The name of the LbEndpoint collection resource. + // + // If the name ends in ``/*``, it indicates an LbEndpoint glob collection, + // which is supported only in the xDS incremental protocol variants. + // Otherwise, it indicates an LbEndpointCollection list collection. + // + // Envoy currently supports only glob collections. string leds_collection_name = 2; } @@ -165,18 +178,20 @@ message LocalityLbEndpoints { core.v3.Metadata metadata = 9; // The group of endpoints belonging to the locality specified. - // [#comment:TODO(adisuissa): Once LEDS is implemented this field needs to be - // deprecated and replaced by ``load_balancer_endpoints``.] + // This is ignored if :ref:`leds_cluster_locality_config + // ` is set. repeated LbEndpoint lb_endpoints = 2; - // [#not-implemented-hide:] oneof lb_config { - // The group of endpoints belonging to the locality. - // [#comment:TODO(adisuissa): Once LEDS is implemented the ``lb_endpoints`` field - // needs to be deprecated.] - LbEndpointList load_balancer_endpoints = 7; + // [#not-implemented-hide:] + // Not implemented and deprecated. + LbEndpointList load_balancer_endpoints = 7 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // LEDS Configuration for the current locality. + // If this is set, the :ref:`lb_endpoints + // ` + // field is ignored. 
LedsClusterLocalityConfig leds_cluster_locality_config = 8; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/load_report.proto b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/load_report.proto index fbd1d36d5d0..6d12765cef5 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/load_report.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/load_report.proto @@ -25,7 +25,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // These are stats Envoy reports to the management server at a frequency defined by // :ref:`LoadStatsResponse.load_reporting_interval`. // Stats per upstream region/zone and optionally per subzone. -// [#next-free-field: 12] +// [#next-free-field: 15] message UpstreamLocalityStats { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.endpoint.UpstreamLocalityStats"; @@ -38,7 +38,8 @@ message UpstreamLocalityStats { // locality. uint64 total_successful_requests = 2; - // The total number of unfinished requests + // The total number of unfinished requests. A request can be an HTTP request + // or a TCP connection for a TCP connection pool. uint64 total_requests_in_progress = 3; // The total number of requests that failed due to errors at the endpoint, @@ -47,7 +48,8 @@ message UpstreamLocalityStats { // The total number of requests that were issued by this Envoy since // the last report. This information is aggregated over all the - // upstream endpoints in the locality. + // upstream endpoints in the locality. A request can be an HTTP request + // or a TCP connection for a TCP connection pool. 
uint64 total_issued_requests = 8; // The total number of connections in an established state at the time of the @@ -75,7 +77,20 @@ message UpstreamLocalityStats { // [#not-implemented-hide:] uint64 total_fail_connections = 11 [(xds.annotations.v3.field_status).work_in_progress = true]; - // Stats for multi-dimensional load balancing. + // CPU utilization stats for multi-dimensional load balancing. + // This typically comes from endpoint metrics reported via ORCA. + UnnamedEndpointLoadMetricStats cpu_utilization = 12; + + // Memory utilization for multi-dimensional load balancing. + // This typically comes from endpoint metrics reported via ORCA. + UnnamedEndpointLoadMetricStats mem_utilization = 13; + + // Blended application-defined utilization for multi-dimensional load balancing. + // This typically comes from endpoint metrics reported via ORCA. + UnnamedEndpointLoadMetricStats application_utilization = 14; + + // Named stats for multi-dimensional load balancing. + // These typically come from endpoint metrics reported via ORCA. repeated EndpointLoadMetricStats load_metric_stats = 5; // Endpoint granularity stats information for this locality. This information @@ -145,6 +160,16 @@ message EndpointLoadMetricStats { double total_metric_value = 3; } +// Same as EndpointLoadMetricStats, except without the metric_name field. +message UnnamedEndpointLoadMetricStats { + // Number of calls that finished and included this metric. + uint64 num_requests_finished_with_metric = 1; + + // Sum of metric values across all calls that finished with this metric for + // load_reporting_interval. + double total_metric_value = 2; +} + // Per cluster load stats. 
Envoy reports these stats a management server in a // :ref:`LoadStatsRequest` // Next ID: 7 diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto index 9381d4eb7ac..54ef2cfed38 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto @@ -5,6 +5,7 @@ package envoy.config.listener.v3; import "envoy/config/accesslog/v3/accesslog.proto"; import "envoy/config/core/v3/address.proto"; import "envoy/config/core/v3/base.proto"; +import "envoy/config/core/v3/config_source.proto"; import "envoy/config/core/v3/extension.proto"; import "envoy/config/core/v3/socket_option.proto"; import "envoy/config/listener/v3/api_listener.proto"; @@ -14,7 +15,6 @@ import "envoy/config/listener/v3/udp_listener_config.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; -import "xds/annotations/v3/status.proto"; import "xds/core/v3/collection_entry.proto"; import "xds/type/matcher/v3/matcher.proto"; @@ -45,6 +45,14 @@ message AdditionalAddress { // or an empty list of :ref:`socket_options `, // it means no socket option will apply. core.v3.SocketOptionsOverride socket_options = 2; + + // Configures TCP keepalive settings for the additional address. + // If not set, the listener :ref:`tcp_keepalive ` + // configuration is inherited. You can explicitly disable TCP keepalive for the additional address by setting any keepalive field + // (:ref:`keepalive_probes `, + // :ref:`keepalive_time `, or + // :ref:`keepalive_interval `) to ``0``. + core.v3.TcpKeepalive tcp_keepalive = 3; } // Listener list collections. Entries are ``Listener`` resources or references. 
@@ -53,7 +61,7 @@ message ListenerCollection { repeated xds.core.v3.CollectionEntry entries = 1; } -// [#next-free-field: 36] +// [#next-free-field: 38] message Listener { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Listener"; @@ -115,6 +123,20 @@ message Listener { message InternalListenerConfig { } + // Configuration for filter chains discovery. + // [#not-implemented-hide:] + message FcdsConfig { + // Optional name to present to the filter chain discovery service. This may be an arbitrary name with arbitrary + // length. If a name is not provided, the listener's name is used. Refer to :ref:`filter_chains `. + // for details on how listener name is determined if unspecified. In addition, this may be a xdstp:// URL. + string name = 1; + + // Configuration for the source of FCDS updates for this listener. + // .. note:: + // This discovery service only supports ``AGGREGATED_GRPC`` API type. + core.v3.ConfigSource config_source = 2; + } + reserved 14, 23; // The unique name by which this listener is known. If no name is provided, @@ -126,6 +148,12 @@ message Listener { // that is governed by the bind rules of the OS. E.g., multiple listeners can listen on port 0 on // Linux as the actual port will be allocated by the OS. // Required unless ``api_listener`` or ``listener_specifier`` is populated. + // + // When the address contains a network namespace filepath (via + // :ref:`network_namespace_filepath `), + // Envoy automatically populates the filter state with key ``envoy.network.network_namespace`` + // when a connection is accepted. This provides read-only access to the network namespace for + // filters, access logs, and other components. core.v3.Address address = 2; // The additional addresses the listener should listen on. The addresses must be unique across all @@ -147,6 +175,12 @@ message Listener { // :ref:`FAQ entry `. repeated FilterChain filter_chains = 3; + // Discover filter chains configurations by external service. 
Dynamic discovery of filter chains is allowed + // while having statically configured filter chains, however, a filter chain name must be unique within a + // listener. If a discovered filter chain matches a name of an existing filter chain, it is discarded. + // [#not-implemented-hide:] + FcdsConfig fcds_config = 36; + // :ref:`Matcher API ` resolving the filter chain name from the // network properties. This matcher is used as a replacement for the filter chain match condition // :ref:`filter_chain_match @@ -163,8 +197,7 @@ message Listener { // connections bound to the filter chain are not drained. If, however, the // filter chain is removed or structurally modified, then the drain for its // connections is initiated. - xds.type.matcher.v3.Matcher filter_chain_matcher = 32 - [(xds.annotations.v3.field_status).work_in_progress = true]; + xds.type.matcher.v3.Matcher filter_chain_matcher = 32; // If a connection is redirected using ``iptables``, the port on which the proxy // receives it might be different from the original destination address. When this flag is set to @@ -247,10 +280,10 @@ message Listener { google.protobuf.BoolValue freebind = 11; // Additional socket options that may not be present in Envoy source code or - // precompiled binaries. The socket options can be updated for a listener when + // precompiled binaries. + // It is not allowed to update the socket options for any existing address if // :ref:`enable_reuse_port ` - // is ``true``. Otherwise, if socket options change during a listener update the update will be rejected - // to make it clear that the options were not updated. + // is ``false`` to avoid the conflict when creating new sockets for the listener. repeated core.v3.SocketOption socket_options = 13; // Whether the listener should accept TCP Fast Open (TFO) connections. @@ -352,6 +385,11 @@ message Listener { // accepted in later event loop iterations. 
// If no value is provided Envoy will accept all connections pending accept // from the kernel. + // + // .. note:: + // + // It is recommended to lower this value for better overload management and reduced per-event cost. + // Setting it to 1 is a viable option with no noticeable impact on performance. google.protobuf.UInt32Value max_connections_to_accept_per_socket_event = 34 [(validate.rules).uint32 = {gt: 0}]; @@ -390,6 +428,12 @@ message Listener { // Whether the listener bypasses configured overload manager actions. bool bypass_overload_manager = 35; + + // If set, TCP keepalive settings are configured for the listener address and inherited by + // additional addresses. If not set, TCP keepalive settings are not configured for the + // listener address and additional addresses by default. See :ref:`tcp_keepalive ` + // to explicitly configure TCP keepalive settings for individual additional addresses. + core.v3.TcpKeepalive tcp_keepalive = 37; } // A placeholder proto so that users can explicitly configure the standard diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto index 2adb8bc2c80..16b43568f39 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto @@ -201,24 +201,9 @@ message FilterChainMatch { message FilterChain { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.listener.FilterChain"; - // The configuration for on-demand filter chain. If this field is not empty in FilterChain message, - // a filter chain will be built on-demand. 
- // On-demand filter chains help speedup the warming up of listeners since the building and initialization of - // an on-demand filter chain will be postponed to the arrival of new connection requests that require this filter chain. - // Filter chains that are not often used can be set as on-demand. - message OnDemandConfiguration { - // The timeout to wait for filter chain placeholders to complete rebuilding. - // 1. If this field is set to 0, timeout is disabled. - // 2. If not specified, a default timeout of 15s is used. - // Rebuilding will wait until dependencies are ready, have failed, or this timeout is reached. - // Upon failure or timeout, all connections related to this filter chain will be closed. - // Rebuilding will start again on the next new connection. - google.protobuf.Duration rebuild_timeout = 1; - } - - reserved 2; + reserved 2, 8; - reserved "tls_context"; + reserved "tls_context", "on_demand_configuration"; // The criteria to use when matching a connection to this filter chain. FilterChainMatch filter_chain_match = 1; @@ -248,7 +233,7 @@ message FilterChain { google.protobuf.BoolValue use_proxy_proto = 4 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - // [#not-implemented-hide:] filter chain metadata. + // Filter chain metadata. core.v3.Metadata metadata = 5; // Optional custom transport socket implementation to use for downstream connections. @@ -265,15 +250,12 @@ message FilterChain { google.protobuf.Duration transport_socket_connect_timeout = 9; // The unique name (or empty) by which this filter chain is known. - // Note: :ref:`filter_chain_matcher - // ` - // requires that filter chains are uniquely named within a listener. + // + // .. note:: + // :ref:`filter_chain_matcher + // ` + // requires that filter chains are uniquely named within a listener. string name = 7; - - // [#not-implemented-hide:] The configuration to specify whether the filter chain will be built on-demand. 
- // If this field is not empty, the filter chain will be built on-demand. - // Otherwise, the filter chain will be built normally and block listener warming. - OnDemandConfiguration on_demand_configuration = 8; } // Listener filter chain match configuration. This is a recursive structure which allows complex diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/quic_config.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/quic_config.proto index 3ddebe900ef..c208a58f4a4 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/quic_config.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/quic_config.proto @@ -5,6 +5,7 @@ package envoy.config.listener.v3; import "envoy/config/core/v3/base.proto"; import "envoy/config/core/v3/extension.proto"; import "envoy/config/core/v3/protocol.proto"; +import "envoy/config/core/v3/socket_cmsg_headers.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; @@ -24,7 +25,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: QUIC listener config] // Configuration specific to the UDP QUIC listener. -// [#next-free-field: 12] +// [#next-free-field: 15] message QuicProtocolOptions { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.listener.QuicProtocolOptions"; @@ -86,4 +87,22 @@ message QuicProtocolOptions { // If not specified, no debug visitor will be attached to connections. // [#extension-category: envoy.quic.connection_debug_visitor] core.v3.TypedExtensionConfig connection_debug_visitor_config = 11; + + // Configure a type of UDP cmsg to pass to listener filters via QuicReceivedPacket. + // Both level and type must be specified for cmsg to be saved. + // Cmsg may be truncated or omitted if expected size is not set. + // If not specified, no cmsg will be saved to QuicReceivedPacket. 
+ repeated core.v3.SocketCmsgHeaders save_cmsg_config = 12 + [(validate.rules).repeated = {max_items: 1}]; + + // If true, the listener will reject connection-establishing packets at the + // QUIC layer by replying with an empty version negotiation packet to the + // client. + bool reject_new_connections = 13; + + // Maximum number of QUIC sessions to create per event loop. + // If not specified, the default value is 16. + // This is an equivalent of the TCP listener option + // max_connections_to_accept_per_socket_event. + google.protobuf.UInt32Value max_sessions_per_event_loop = 14 [(validate.rules).uint32 = {gt: 0}]; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/metrics/v3/stats.proto b/xds/third_party/envoy/src/main/proto/envoy/config/metrics/v3/stats.proto index e7d7f80d648..0fcf36c1c71 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/metrics/v3/stats.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/metrics/v3/stats.proto @@ -60,11 +60,6 @@ message StatsConfig { // `. They will be processed before // the custom tags. // - // .. note:: - // - // If any default tags are specified twice, the config will be considered - // invalid. - // // See :repo:`well_known_names.h ` for a list of the // default tags in Envoy. // @@ -298,10 +293,12 @@ message HistogramBucketSettings { // Each value is the upper bound of a bucket. Each bucket must be greater than 0 and unique. // The order of the buckets does not matter. repeated double buckets = 2 [(validate.rules).repeated = { - min_items: 1 unique: true items {double {gt: 0.0}} }]; + + // Initial number of bins for the ``circllhist`` thread local histogram per time series. Default value is 100. + google.protobuf.UInt32Value bins = 3 [(validate.rules).uint32 = {lte: 46082 gt: 0}]; } // Stats configuration proto schema for built-in ``envoy.stat_sinks.statsd`` sink. 
This sink does not support diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto b/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto index d3b8b01a173..b5bc2c4d830 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto @@ -103,6 +103,19 @@ message ScaleTimersOverloadActionConfig { // This affects the value of // :ref:`FilterChain.transport_socket_connect_timeout `. TRANSPORT_SOCKET_CONNECT = 3; + + // Adjusts the max connection duration timer for downstream HTTP connections. + // This affects the value of + // :ref:`HttpConnectionManager.common_http_protocol_options.max_connection_duration + // `. + HTTP_DOWNSTREAM_CONNECTION_MAX = 4; + + // Adjusts the timeout for the downstream codec to flush an ended stream. + // This affects the value of :ref:`RouteAction.flush_timeout + // ` and + // :ref:`HttpConnectionManager.stream_flush_timeout + // ` + HTTP_DOWNSTREAM_STREAM_FLUSH = 5; } message ScaleTimer { @@ -128,9 +141,16 @@ message OverloadAction { option (udpa.annotations.versioning).previous_message_type = "envoy.config.overload.v2alpha.OverloadAction"; - // The name of the overload action. This is just a well-known string that listeners can - // use for registering callbacks. Custom overload actions should be named using reverse - // DNS to ensure uniqueness. + // The name of the overload action. This is just a well-known string that + // listeners can use for registering callbacks. 
+ // Valid known overload actions include: + // - envoy.overload_actions.stop_accepting_requests + // - envoy.overload_actions.disable_http_keepalive + // - envoy.overload_actions.stop_accepting_connections + // - envoy.overload_actions.reject_incoming_connections + // - envoy.overload_actions.shrink_heap + // - envoy.overload_actions.reduce_timeouts + // - envoy.overload_actions.reset_high_memory_stream string name = 1 [(validate.rules).string = {min_len: 1}]; // A set of triggers for this action. The state of the action is the maximum @@ -142,7 +162,7 @@ message OverloadAction { // in this list. repeated Trigger triggers = 2 [(validate.rules).repeated = {min_items: 1}]; - // Configuration for the action being instantiated. + // Configuration for the action being instantiated if applicable. google.protobuf.Any typed_config = 3; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto b/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto index 8d98fd7155d..ef153ad177b 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.config.rbac.v3; import "envoy/config/core/v3/address.proto"; +import "envoy/config/core/v3/cel.proto"; import "envoy/config/core/v3/extension.proto"; import "envoy/config/route/v3/route_components.proto"; import "envoy/type/matcher/v3/filter_state.proto"; @@ -28,6 +29,14 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Role Based Access Control (RBAC)] +enum MetadataSource { + // Query :ref:`dynamic metadata ` + DYNAMIC = 0; + + // Query :ref:`route metadata ` + ROUTE = 1; +} + // Role Based Access Control (RBAC) provides service-level and method-level access control for a // service. Requests are allowed or denied based on the ``action`` and whether a matching policy is // found. 
For instance, if the action is ALLOW and a matching policy is found the request should be @@ -165,6 +174,7 @@ message RBAC { // A policy matches if and only if at least one of its permissions match the // action taking place AND at least one of its principals match the downstream // AND the condition is true if specified. +// [#next-free-field: 6] message Policy { option (udpa.annotations.versioning).previous_message_type = "envoy.config.rbac.v2.Policy"; @@ -191,10 +201,37 @@ message Policy { // Only be used when condition is not used. google.api.expr.v1alpha1.CheckedExpr checked_condition = 4 [(udpa.annotations.field_migrate).oneof_promotion = "expression_specifier"]; + + // CEL expression configuration that modifies the evaluation behavior of the ``condition`` field. + // If specified, string conversion, concatenation, and manipulation functions may be enabled + // for the CEL expression. See :ref:`CelExpressionConfig ` + // for more details. + core.v3.CelExpressionConfig cel_config = 5; +} + +// SourcedMetadata enables matching against metadata from different sources in the request processing +// pipeline. It extends the base MetadataMatcher functionality by allowing specification of where the +// metadata should be sourced from, rather than only matching against dynamic metadata. +// +// The matcher can be configured to look up metadata from: +// +// * Dynamic metadata: Runtime metadata added by filters during request processing +// * Route metadata: Static metadata configured on the route entry +// +message SourcedMetadata { + // Metadata matcher configuration that defines what metadata to match against. This includes the filter name, + // metadata key path, and expected value. + type.matcher.v3.MetadataMatcher metadata_matcher = 1 + [(validate.rules).message = {required: true}]; + + // Specifies which metadata source should be used for matching. If not set, + // defaults to DYNAMIC (dynamic metadata). 
Set to ROUTE to match against + // static metadata configured on the route entry. + MetadataSource metadata_source = 2 [(validate.rules).enum = {defined_only: true}]; } // Permission defines an action (or actions) that a principal can take. -// [#next-free-field: 14] +// [#next-free-field: 15] message Permission { option (udpa.annotations.versioning).previous_message_type = "envoy.config.rbac.v2.Permission"; @@ -219,10 +256,14 @@ message Permission { // When any is set, it matches any action. bool any = 3 [(validate.rules).bool = {const: true}]; - // A header (or pseudo-header such as :path or :method) on the incoming HTTP request. Only - // available for HTTP request. - // Note: the pseudo-header :path includes the query and fragment string. Use the ``url_path`` - // field if you want to match the URL path without the query and fragment string. + // A header (or pseudo-header such as ``:path`` or ``:method``) on the incoming HTTP request. Only available + // for HTTP request. + // + // .. note:: + // + // The pseudo-header ``:path`` includes the query and fragment string. Use the ``url_path`` field if you + // want to match the URL path without the query and fragment string. + // route.v3.HeaderMatcher header = 4; // A URL path on the incoming HTTP request. Only available for HTTP. @@ -237,16 +278,17 @@ message Permission { // A port number range that describes a range of destination ports connecting to. type.v3.Int32Range destination_port_range = 11; - // Metadata that describes additional information about the action. - type.matcher.v3.MetadataMatcher metadata = 7; + // Metadata that describes additional information about the action. This field is deprecated; please use + // :ref:`sourced_metadata` instead. + type.matcher.v3.MetadataMatcher metadata = 7 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Negates matching the provided permission. 
For instance, if the value of // ``not_rule`` would match, this permission would not match. Conversely, if // the value of ``not_rule`` would not match, this permission would match. Permission not_rule = 8; - // The request server from the client's connection request. This is - // typically TLS SNI. + // The request server from the client's connection request. This is typically TLS SNI. // // .. attention:: // @@ -263,8 +305,7 @@ message Permission { // * A :ref:`listener filter ` may // overwrite a connection's requested server name within Envoy. // - // Please refer to :ref:`this FAQ entry ` to learn to - // setup SNI. + // Please refer to :ref:`this FAQ entry ` to learn how to setup SNI. type.matcher.v3.StringMatcher requested_server_name = 9; // Extension for configuring custom matchers for RBAC. @@ -274,12 +315,16 @@ message Permission { // URI template path matching. // [#extension-category: envoy.path.match] core.v3.TypedExtensionConfig uri_template = 13; + + // Matches against metadata from either dynamic state or route configuration. Preferred over the + // ``metadata`` field as it provides more flexibility in metadata source selection. + SourcedMetadata sourced_metadata = 14; } } // Principal defines an identity or a group of identities for a downstream // subject. -// [#next-free-field: 13] +// [#next-free-field: 15] message Principal { option (udpa.annotations.versioning).previous_message_type = "envoy.config.rbac.v2.Principal"; @@ -293,6 +338,10 @@ message Principal { } // Authentication attributes for a downstream. + // It is recommended to NOT use this type, but instead use + // :ref:`MTlsAuthenticated `, + // configured via :ref:`custom `, + // which should be used for most use cases due to its improved security. message Authenticated { option (udpa.annotations.versioning).previous_message_type = "envoy.config.rbac.v2.Principal.Authenticated"; @@ -301,25 +350,31 @@ message Principal { // The name of the principal. 
If set, The URI SAN or DNS SAN in that order // is used from the certificate, otherwise the subject field is used. If - // unset, it applies to any user that is authenticated. + // unset, it applies to any user that is allowed by the downstream TLS configuration. + // If :ref:`require_client_certificate ` + // is false or :ref:`trust_chain_verification ` + // is set to :ref:`ACCEPT_UNTRUSTED `, + // then no authentication is required. type.matcher.v3.StringMatcher principal_name = 2; } oneof identifier { option (validate.required) = true; - // A set of identifiers that all must match in order to define the - // downstream. + // A set of identifiers that all must match in order to define the downstream. Set and_ids = 1; - // A set of identifiers at least one must match in order to define the - // downstream. + // A set of identifiers at least one must match in order to define the downstream. Set or_ids = 2; // When any is set, it matches any downstream. bool any = 3 [(validate.rules).bool = {const: true}]; // Authenticated attributes that identify the downstream. + // It is recommended to NOT use this field, but instead use + // :ref:`MTlsAuthenticated `, + // configured via :ref:`custom `, + // which should be used for most use cases due to its improved security. Authenticated authenticated = 4; // A CIDR block that describes the downstream IP. @@ -333,31 +388,42 @@ message Principal { [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // A CIDR block that describes the downstream remote/origin address. - // Note: This is always the physical peer even if the - // :ref:`remote_ip ` is - // inferred from for example the x-forwarder-for header, proxy protocol, - // etc. + // + // .. note:: + // + // This is always the physical peer even if the + // :ref:`remote_ip ` is inferred from the + // x-forwarder-for header, the proxy protocol, etc. 
+ // core.v3.CidrRange direct_remote_ip = 10; // A CIDR block that describes the downstream remote/origin address. - // Note: This may not be the physical peer and could be different from the - // :ref:`direct_remote_ip - // `. E.g, if the - // remote ip is inferred from for example the x-forwarder-for header, proxy - // protocol, etc. + // + // .. note:: + // + // This may not be the physical peer and could be different from the :ref:`direct_remote_ip + // `. E.g, if the remote ip is inferred from + // the x-forwarder-for header, the proxy protocol, etc. + // core.v3.CidrRange remote_ip = 11; - // A header (or pseudo-header such as :path or :method) on the incoming HTTP - // request. Only available for HTTP request. Note: the pseudo-header :path - // includes the query and fragment string. Use the ``url_path`` field if you - // want to match the URL path without the query and fragment string. + // A header (or pseudo-header such as ``:path`` or ``:method``) on the incoming HTTP request. Only available + // for HTTP request. + // + // .. note:: + // + // The pseudo-header ``:path`` includes the query and fragment string. Use the ``url_path`` field if you + // want to match the URL path without the query and fragment string. + // route.v3.HeaderMatcher header = 6; // A URL path on the incoming HTTP request. Only available for HTTP. type.matcher.v3.PathMatcher url_path = 9; - // Metadata that describes additional information about the principal. - type.matcher.v3.MetadataMatcher metadata = 7; + // Metadata that describes additional information about the principal. This field is deprecated; please use + // :ref:`sourced_metadata` instead. + type.matcher.v3.MetadataMatcher metadata = 7 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Identifies the principal using a filter state object. 
type.matcher.v3.FilterStateMatcher filter_state = 12; @@ -366,6 +432,14 @@ message Principal { // ``not_id`` would match, this principal would not match. Conversely, if the // value of ``not_id`` would not match, this principal would match. Principal not_id = 8; + + // Matches against metadata from either dynamic state or route configuration. Preferred over the + // ``metadata`` field as it provides more flexibility in metadata source selection. + SourcedMetadata sourced_metadata = 13; + + // Extension for configuring custom principals for RBAC. + // [#extension-category: envoy.rbac.principals] + core.v3.TypedExtensionConfig custom = 14; } } @@ -377,7 +451,7 @@ message Action { // The action to take if the matcher matches. Every action either allows or denies a request, // and can also carry out action-specific operations. // - // Actions: + // **Actions:** // // * ``ALLOW``: If the request gets matched on ALLOW, it is permitted. // * ``DENY``: If the request gets matched on DENY, it is not permitted. @@ -386,7 +460,7 @@ message Action { // ``envoy.common`` will be set to the value ``true``. // * If the request cannot get matched, it will fallback to ``DENY``. 
// - // Log behavior: + // **Log behavior:** // // If the RBAC matcher contains at least one LOG action, the dynamic // metadata key ``access_log_hint`` will be set based on if the request diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto index c4d507d22b0..5bd909f34c3 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto @@ -23,7 +23,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // * Routing :ref:`architecture overview ` // * HTTP :ref:`router filter ` -// [#next-free-field: 18] +// [#next-free-field: 19] message RouteConfiguration { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.RouteConfiguration"; @@ -129,10 +129,17 @@ message RouteConfiguration { // By default, port in :authority header (if any) is used in host matching. // With this option enabled, Envoy will ignore the port number in the :authority header (if any) when picking VirtualHost. - // NOTE: this option will not strip the port number (if any) contained in route config - // :ref:`envoy_v3_api_msg_config.route.v3.VirtualHost`.domains field. + // + // .. note:: + // This option will not strip the port number (if any) contained in route config + // :ref:`envoy_v3_api_msg_config.route.v3.VirtualHost`.domains field. bool ignore_port_in_host_matching = 14; + // Normally, virtual host matching is done using the :authority (or + // Host: in HTTP < 2) HTTP header. Setting this will instead, use a + // different HTTP header for this purpose. + string vhost_header = 18; + // Ignore path-parameters in path-matching. // Before RFC3986, URI were like(RFC1808): :///;?# // Envoy by default takes ":path" as ";". 
diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto index 7e2ff33da5c..4587ef10487 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto @@ -2,9 +2,12 @@ syntax = "proto3"; package envoy.config.route.v3; +import "envoy/config/common/mutation_rules/v3/mutation_rules.proto"; import "envoy/config/core/v3/base.proto"; import "envoy/config/core/v3/extension.proto"; import "envoy/config/core/v3/proxy_protocol.proto"; +import "envoy/config/core/v3/substitution_format_string.proto"; +import "envoy/type/matcher/v3/filter_state.proto"; import "envoy/type/matcher/v3/metadata.proto"; import "envoy/type/matcher/v3/regex.proto"; import "envoy/type/matcher/v3/string.proto"; @@ -17,7 +20,6 @@ import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; -import "xds/annotations/v3/status.proto"; import "xds/type/matcher/v3/matcher.proto"; import "envoy/annotations/deprecation.proto"; @@ -41,7 +43,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // host header. This allows a single listener to service multiple top level domain path trees. Once // a virtual host is selected based on the domain, the routes are processed in order to see which // upstream cluster to route to or whether to perform a redirect. -// [#next-free-field: 25] +// [#next-free-field: 26] message VirtualHost { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.VirtualHost"; @@ -78,7 +80,7 @@ message VirtualHost { // .. note:: // // The wildcard will not match the empty string. - // e.g. ``*-bar.foo.com`` will match ``baz-bar.foo.com`` but not ``-bar.foo.com``. + // For example, ``*-bar.foo.com`` will match ``baz-bar.foo.com`` but not ``-bar.foo.com``. 
// The longest wildcards match first. // Only a single virtual host in the entire route configuration can match on ``*``. A domain // must be unique across all virtual hosts or the config will fail to load. @@ -92,13 +94,12 @@ message VirtualHost { // The list of routes that will be matched, in order, for incoming requests. // The first route that matches will be used. // Only one of this and ``matcher`` can be specified. - repeated Route routes = 3; + repeated Route routes = 3 [(udpa.annotations.field_migrate).oneof_promotion = "route_selection"]; - // [#next-major-version: This should be included in a oneof with routes wrapped in a message.] // The match tree to use when resolving route actions for incoming requests. Only one of this and ``routes`` // can be specified. xds.type.matcher.v3.Matcher matcher = 21 - [(xds.annotations.v3.field_status).work_in_progress = true]; + [(udpa.annotations.field_migrate).oneof_promotion = "route_selection"]; // Specifies the type of TLS enforcement the virtual host expects. If this option is not // specified, there is no TLS requirement for the virtual host. @@ -156,7 +157,7 @@ message VirtualHost { // This field can be used to provide virtual host level per filter config. The key should match the // :ref:`filter config name // `. - // See :ref:`Http filter route specific config ` + // See :ref:`HTTP filter route-specific config ` // for details. // [#comment: An entry's value may be wrapped in a // :ref:`FilterConfig` @@ -167,7 +168,10 @@ message VirtualHost { // ` header should be included // in the upstream request. Setting this option will cause it to override any existing header // value, so in the case of two Envoys on the request path with this option enabled, the upstream - // will see the attempt count as perceived by the second Envoy. Defaults to false. + // will see the attempt count as perceived by the second Envoy. + // + // Defaults to ``false``. 
+ // // This header is unaffected by the // :ref:`suppress_envoy_headers // ` flag. @@ -179,7 +183,10 @@ message VirtualHost { // ` header should be included // in the downstream response. Setting this option will cause the router to override any existing header // value, so in the case of two Envoys on the request path with this option enabled, the downstream - // will see the attempt count as perceived by the Envoy closest upstream from itself. Defaults to false. + // will see the attempt count as perceived by the Envoy closest upstream from itself. + // + // Defaults to ``false``. + // // This header is unaffected by the // :ref:`suppress_envoy_headers // ` flag. @@ -187,29 +194,56 @@ message VirtualHost { // Indicates the retry policy for all routes in this virtual host. Note that setting a // route level entry will take precedence over this config and it'll be treated - // independently (e.g.: values are not inherited). + // independently (e.g., values are not inherited). RetryPolicy retry_policy = 16; // [#not-implemented-hide:] // Specifies the configuration for retry policy extension. Note that setting a route level entry - // will take precedence over this config and it'll be treated independently (e.g.: values are not + // will take precedence over this config and it'll be treated independently (e.g., values are not // inherited). :ref:`Retry policy ` should not be // set if this field is used. google.protobuf.Any retry_policy_typed_config = 20; // Indicates the hedge policy for all routes in this virtual host. Note that setting a // route level entry will take precedence over this config and it'll be treated - // independently (e.g.: values are not inherited). + // independently (e.g., values are not inherited). HedgePolicy hedge_policy = 17; // Decides whether to include the :ref:`x-envoy-is-timeout-retry ` - // request header in retries initiated by per try timeouts. + // request header in retries initiated by per-try timeouts. 
bool include_is_timeout_retry_header = 23; - // The maximum bytes which will be buffered for retries and shadowing. - // If set and a route-specific limit is not set, the bytes actually buffered will be the minimum - // value of this and the listener per_connection_buffer_limit_bytes. - google.protobuf.UInt32Value per_request_buffer_limit_bytes = 18; + // The maximum bytes which will be buffered for retries and shadowing. If set, the bytes actually buffered will be + // the minimum value of this and the listener ``per_connection_buffer_limit_bytes``. + // + // .. attention:: + // + // This field has been deprecated. Please use :ref:`request_body_buffer_limit + // ` instead. + // Only one of ``per_request_buffer_limit_bytes`` and ``request_body_buffer_limit`` could be set. + google.protobuf.UInt32Value per_request_buffer_limit_bytes = 18 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // The maximum bytes which will be buffered for request bodies to support large request body + // buffering beyond the ``per_connection_buffer_limit_bytes``. + // + // This limit is specifically for the request body buffering and allows buffering larger payloads while maintaining + // flow control. + // + // Buffer limit precedence (from highest to lowest priority): + // + // 1. If ``request_body_buffer_limit`` is set, then ``request_body_buffer_limit`` will be used. + // 2. If :ref:`per_request_buffer_limit_bytes ` + // is set but ``request_body_buffer_limit`` is not, then ``min(per_request_buffer_limit_bytes, per_connection_buffer_limit_bytes)`` + // will be used. + // 3. If neither is set, then ``per_connection_buffer_limit_bytes`` will be used. + // + // For flow control chunk sizes, ``min(per_connection_buffer_limit_bytes, 16KB)`` will be used. + // + // Only one of :ref:`per_request_buffer_limit_bytes ` + // and ``request_body_buffer_limit`` could be set. 
+ google.protobuf.UInt64Value request_body_buffer_limit = 25 + [(validate.rules).message = {required: false}]; // Specify a set of default request mirroring policies for every route under this virtual host. // It takes precedence over the route config mirror policy entirely. @@ -245,7 +279,7 @@ message RouteList { // // Envoy supports routing on HTTP method via :ref:`header matching // `. -// [#next-free-field: 20] +// [#next-free-field: 21] message Route { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.Route"; @@ -298,7 +332,7 @@ message Route { // This field can be used to provide route specific per filter config. The key should match the // :ref:`filter config name // `. - // See :ref:`Http filter route specific config ` + // See :ref:`HTTP filter route-specific config ` // for details. // [#comment: An entry's value may be wrapped in a // :ref:`FilterConfig` @@ -342,7 +376,14 @@ message Route { // The maximum bytes which will be buffered for retries and shadowing. // If set, the bytes actually buffered will be the minimum value of this and the // listener per_connection_buffer_limit_bytes. - google.protobuf.UInt32Value per_request_buffer_limit_bytes = 16; + // + // .. attention:: + // + // This field has been deprecated. Please use :ref:`request_body_buffer_limit + // ` instead. + // Only one of ``per_request_buffer_limit_bytes`` and ``request_body_buffer_limit`` may be set. + google.protobuf.UInt32Value per_request_buffer_limit_bytes = 16 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // The human readable prefix to use when emitting statistics for this endpoint. // The statistics are rooted at vhost..route.. @@ -356,8 +397,27 @@ message Route { // // We do not recommend setting up a stat prefix for // every application endpoint. This is both not easily maintainable and - // statistics use a non-trivial amount of memory(approximately 1KiB per route). 
+ // statistics use a non-trivial amount of memory (approximately 1KiB per route). string stat_prefix = 19; + + // The maximum bytes which will be buffered for request bodies to support large request body + // buffering beyond the ``per_connection_buffer_limit_bytes``. + // + // This limit is specifically for the request body buffering and allows buffering larger payloads while maintaining + // flow control. + // + // Buffer limit precedence (from highest to lowest priority): + // + // 1. If ``request_body_buffer_limit`` is set: use ``request_body_buffer_limit`` + // 2. If :ref:`per_request_buffer_limit_bytes ` + // is set but ``request_body_buffer_limit`` is not: use ``min(per_request_buffer_limit_bytes, per_connection_buffer_limit_bytes)`` + // 3. If neither is set: use ``per_connection_buffer_limit_bytes`` + // + // For flow control chunk sizes, use ``min(per_connection_buffer_limit_bytes, 16KB)``. + // + // Only one of :ref:`per_request_buffer_limit_bytes ` + // and ``request_body_buffer_limit`` may be set. + google.protobuf.UInt64Value request_body_buffer_limit = 20; } // Compared to the :ref:`cluster ` field that specifies a @@ -366,6 +426,7 @@ message Route { // multiple upstream clusters along with weights that indicate the percentage of // traffic to be forwarded to each cluster. The router selects an upstream cluster based on the // weights. +// [#next-free-field: 6] message WeightedCluster { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.WeightedCluster"; @@ -453,7 +514,7 @@ message WeightedCluster { // This field can be used to provide weighted cluster specific per filter config. The key should match the // :ref:`filter config name // `. - // See :ref:`Http filter route specific config ` + // See :ref:`HTTP filter route-specific config ` // for details. // [#comment: An entry's value may be wrapped in a // :ref:`FilterConfig` @@ -496,12 +557,18 @@ message WeightedCluster { // the process for the consistency. 
And the value is a unsigned number between 0 and UINT64_MAX. string header_name = 4 [(validate.rules).string = {well_known_regex: HTTP_HEADER_NAME strict: false}]; + + // When set to true, the hash policies will be used to generate the random value for weighted cluster selection. + // This could ensure consistent cluster picking across multiple proxy levels for weighted traffic. + google.protobuf.BoolValue use_hash_policy = 5; } } // Configuration for a cluster specifier plugin. message ClusterSpecifierPlugin { // The name of the plugin and its opaque configuration. + // + // [#extension-category: envoy.router.cluster_specifier_plugin] core.v3.TypedExtensionConfig extension = 1 [(validate.rules).message = {required: true}]; // If is_optional is not set or is set to false and the plugin defined by this message is not a @@ -512,7 +579,7 @@ message ClusterSpecifierPlugin { bool is_optional = 2; } -// [#next-free-field: 16] +// [#next-free-field: 18] message RouteMatch { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RouteMatch"; @@ -570,7 +637,7 @@ message RouteMatch { // // [#next-major-version: In the v3 API we should redo how path specification works such // that we utilize StringMatcher, and additionally have consistent options around whether we - // strip query strings, do a case sensitive match, etc. In the interim it will be too disruptive + // strip query strings, do a case-sensitive match, etc. In the interim it will be too disruptive // to deprecate the existing options. We should even consider whether we want to do away with // path_specifier entirely and just rely on a set of header matchers which can already match // on :path, etc. The issue with that is it is unclear how to generically deal with query string @@ -602,7 +669,7 @@ message RouteMatch { core.v3.TypedExtensionConfig path_match_policy = 15; } - // Indicates that prefix/path matching should be case sensitive. 
The default + // Indicates that prefix/path matching should be case-sensitive. The default // is true. Ignored for safe_regex matching. google.protobuf.BoolValue case_sensitive = 4; @@ -642,14 +709,19 @@ message RouteMatch { // // If query parameters are used to pass request message fields when // `grpc_json_transcoder `_ - // is used, the transcoded message fields maybe different. The query parameters are - // url encoded, but the message fields are not. For example, if a query + // is used, the transcoded message fields may be different. The query parameters are + // URL-encoded, but the message fields are not. For example, if a query // parameter is "foo%20bar", the message field will be "foo bar". repeated QueryParameterMatcher query_parameters = 7; + // Specifies a set of cookies on which the route should match. The router parses the ``Cookie`` + // header and evaluates the named cookie against each matcher. If the number of specified cookie + // matchers is nonzero, they all must match for the route to be selected. + repeated CookieMatcher cookies = 17; + // If specified, only gRPC requests will be matched. The router will check - // that the content-type header has a application/grpc or one of the various - // application/grpc+ values. + // that the ``Content-Type`` header has ``application/grpc`` or one of the various + // ``application/grpc+`` values. GrpcRouteMatchOptions grpc = 8; // If specified, the client tls context will be matched against the defined @@ -663,6 +735,12 @@ message RouteMatch { // If the number of specified dynamic metadata matchers is nonzero, they all must match the // dynamic metadata for a match to occur. repeated type.matcher.v3.MetadataMatcher dynamic_metadata = 13; + + // Specifies a set of filter state matchers on which the route should match. + // The router will check the filter state against all the specified filter state matchers. 
+ // If the number of specified filter state matchers is nonzero, they all must match the + // filter state for a match to occur. + repeated type.matcher.v3.FilterStateMatcher filter_state = 16; } // Cors policy configuration. @@ -729,11 +807,11 @@ message CorsPolicy { google.protobuf.BoolValue allow_private_network_access = 12; // Specifies if preflight requests not matching the configured allowed origin should be forwarded - // to the upstream. Default is true. + // to the upstream. Default is ``true``. google.protobuf.BoolValue forward_not_matching_preflights = 13; } -// [#next-free-field: 42] +// [#next-free-field: 46] message RouteAction { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RouteAction"; @@ -772,8 +850,8 @@ message RouteAction { // // .. note:: // - // Shadowing doesn't support Http CONNECT and upgrades. - // [#next-free-field: 7] + // Shadowing doesn't support HTTP CONNECT and upgrades. + // [#next-free-field: 9] message RequestMirrorPolicy { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RouteAction.RequestMirrorPolicy"; @@ -817,11 +895,30 @@ message RouteAction { // value, the request will be mirrored. core.v3.RuntimeFractionalPercent runtime_fraction = 3; - // Determines if the trace span should be sampled. Defaults to true. + // Specifies whether the trace span for the shadow request should be sampled. If this field is not explicitly set, + // the shadow request will inherit the sampling decision of its parent span. This ensures consistency with the trace + // sampling policy of the original request and prevents oversampling, especially in scenarios where runtime sampling + // is disabled. google.protobuf.BoolValue trace_sampled = 4; - // Disables appending the ``-shadow`` suffix to the shadowed ``Host`` header. Defaults to ``false``. + // Disables appending the ``-shadow`` suffix to the shadowed ``Host`` header. + // + // Defaults to ``false``. 
bool disable_shadow_host_suffix_append = 6; + + // Specifies a list of header mutations that should be applied to each mirrored request. + // Header mutations are applied in the order they are specified. For more information, including + // details on header value syntax, see the documentation on :ref:`custom request headers + // `. + repeated common.mutation_rules.v3.HeaderMutation request_headers_mutations = 7 + [(validate.rules).repeated = {max_items: 1000}]; + + // Indicates that during mirroring, the host header will be swapped with this value. + // :ref:`disable_shadow_host_suffix_append + // ` + // is implicitly enabled if this field is set. + string host_rewrite_literal = 8 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; } // Specifies the route's hashing policy if the upstream cluster uses a hashing :ref:`load balancer @@ -983,13 +1080,15 @@ message RouteAction { bool allow_post = 2; } - // The case-insensitive name of this upgrade, e.g. "websocket". + // The case-insensitive name of this upgrade, for example, "websocket". // For each upgrade type present in upgrade_configs, requests with // Upgrade: [upgrade_type] will be proxied upstream. string upgrade_type = 1 [(validate.rules).string = {min_len: 1 well_known_regex: HTTP_HEADER_VALUE strict: false}]; - // Determines if upgrades are available on this route. Defaults to true. + // Determines if upgrades are available on this route. + // + // Defaults to ``true``. google.protobuf.BoolValue enabled = 2; // Configuration for sending data upstream as a raw data payload. This is used for @@ -1088,9 +1187,11 @@ message RouteAction { // place the original path before rewrite into the :ref:`x-envoy-original-path // ` header. // - // Only one of :ref:`regex_rewrite ` + // Only one of :ref:`regex_rewrite `, // :ref:`path_rewrite_policy `, - // or :ref:`prefix_rewrite ` may be specified. + // :ref:`path_rewrite `, + // or :ref:`prefix_rewrite ` + // may be specified. // // .. 
attention:: // @@ -1126,8 +1227,9 @@ message RouteAction { // ` header. // // Only one of :ref:`regex_rewrite `, - // :ref:`prefix_rewrite `, or - // :ref:`path_rewrite_policy `] + // :ref:`path_rewrite_policy `, + // :ref:`path_rewrite `, + // or :ref:`prefix_rewrite ` // may be specified. // // Examples using Google's `RE2 `_ engine: @@ -1151,12 +1253,48 @@ message RouteAction { // [#extension-category: envoy.path.rewrite] core.v3.TypedExtensionConfig path_rewrite_policy = 41; + // Rewrites the whole path (without query parameters) with the given path value. + // The router filter will + // place the original path before rewrite into the :ref:`x-envoy-original-path + // ` header. + // + // Only one of :ref:`regex_rewrite `, + // :ref:`path_rewrite_policy `, + // :ref:`path_rewrite `, + // or :ref:`prefix_rewrite ` + // may be specified. + // + // The :ref:`substitution format specifier ` could be applied here. + // For example, with the following config: + // + // .. code-block:: yaml + // + // path_rewrite: "/new_path_prefix%REQ(custom-path-header-name)%" + // + // Would rewrite the path to ``/new_path_prefix/some_value`` given the header + // ``custom-path-header-name: some_value``. If the header is not present, the path will be + // rewritten to ``/new_path_prefix``. + // + // + // If the final output of the path rewrite is empty, then the update will be ignored and the + // original path will be preserved. + string path_rewrite = 45; + + // If one of the host rewrite specifiers is set and the + // :ref:`suppress_envoy_headers + // ` flag is not + // set to true, the router filter will place the original host header value before + // rewriting into the :ref:`x-envoy-original-host + // ` header. + // + // And if the + // :ref:`append_x_forwarded_host ` + // is set to true, the original host value will also be appended to the + // :ref:`config_http_conn_man_headers_x-forwarded-host` header. 
+ // oneof host_rewrite_specifier { // Indicates that during forwarding, the host header will be swapped with - // this value. Using this option will append the - // :ref:`config_http_conn_man_headers_x-forwarded-host` header if - // :ref:`append_x_forwarded_host ` - // is set. + // this value. string host_rewrite_literal = 6 [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; @@ -1166,18 +1304,12 @@ message RouteAction { // type ``strict_dns`` or ``logical_dns``, // or when :ref:`hostname ` // field is not empty. Setting this to true with other cluster types - // has no effect. Using this option will append the - // :ref:`config_http_conn_man_headers_x-forwarded-host` header if - // :ref:`append_x_forwarded_host ` - // is set. + // has no effect. google.protobuf.BoolValue auto_host_rewrite = 7; // Indicates that during forwarding, the host header will be swapped with the content of given // downstream or :ref:`custom ` header. - // If header value is empty, host header is left intact. Using this option will append the - // :ref:`config_http_conn_man_headers_x-forwarded-host` header if - // :ref:`append_x_forwarded_host ` - // is set. + // If header value is empty, host header is left intact. // // .. attention:: // @@ -1193,10 +1325,6 @@ message RouteAction { // Indicates that during forwarding, the host header will be swapped with // the result of the regex substitution executed on path value with query and fragment removed. // This is useful for transitioning variable content between path segment and subdomain. - // Using this option will append the - // :ref:`config_http_conn_man_headers_x-forwarded-host` header if - // :ref:`append_x_forwarded_host ` - // is set. // // For example with the following config: // @@ -1210,6 +1338,25 @@ message RouteAction { // // Would rewrite the host header to ``envoyproxy.io`` given the path ``/envoyproxy.io/some/path``. 
type.matcher.v3.RegexMatchAndSubstitute host_rewrite_path_regex = 35; + + // Rewrites the host header with the value of this field. The router filter will + // place the original host header value before rewriting into the :ref:`x-envoy-original-host + // ` header. + // + // The :ref:`substitution format specifier ` could be applied here. + // For example, with the following config: + // + // .. code-block:: yaml + // + // host_rewrite: "prefix-%REQ(custom-host-header-name)%" + // + // Would rewrite the host header to ``prefix-some_value`` given the header + // ``custom-host-header-name: some_value``. If the header is not present, the host header will + // be rewritten to an value of ``prefix-``. + // + // If the final output of the host rewrite is empty, then the update will be ignored and the + // original host header will be preserved. + string host_rewrite = 44; } // If set, then a host rewrite action (one of @@ -1256,8 +1403,28 @@ message RouteAction { // If the :ref:`overload action ` "envoy.overload_actions.reduce_timeouts" // is configured, this timeout is scaled according to the value for // :ref:`HTTP_DOWNSTREAM_STREAM_IDLE `. + // + // This timeout may also be used in place of ``flush_timeout`` in very specific cases. See the + // documentation for ``flush_timeout`` for more details. google.protobuf.Duration idle_timeout = 24; + // Specifies the codec stream flush timeout for the route. + // + // If not specified, the first preference is the global :ref:`stream_flush_timeout + // `, + // but only if explicitly configured. + // + // If neither the explicit HCM-wide flush timeout nor this route-specific flush timeout is configured, + // the route's stream idle timeout is reused for this timeout. This is for + // backwards compatibility since both behaviors were historically controlled by the one timeout. + // + // If the route also does not have an idle timeout configured, the global :ref:`stream_idle_timeout + // `. 
used, again + // for backwards compatibility. That timeout defaults to 5 minutes. + // + // A value of 0 via any of the above paths will completely disable the timeout for a given route. + google.protobuf.Duration flush_timeout = 42; + // Specifies how to send request over TLS early data. // If absent, allows `safe HTTP requests `_ to be sent on early data. // [#extension-category: envoy.route.early_data_policy] @@ -1265,13 +1432,13 @@ message RouteAction { // Indicates that the route has a retry policy. Note that if this is set, // it'll take precedence over the virtual host level retry policy entirely - // (e.g.: policies are not merged, most internal one becomes the enforced policy). + // (e.g., policies are not merged, the most internal one becomes the enforced policy). RetryPolicy retry_policy = 9; // [#not-implemented-hide:] // Specifies the configuration for retry policy extension. Note that if this is set, it'll take - // precedence over the virtual host level retry policy entirely (e.g.: policies are not merged, - // most internal one becomes the enforced policy). :ref:`Retry policy ` + // precedence over the virtual host level retry policy entirely (e.g., policies are not merged, + // the most internal one becomes the enforced policy). :ref:`Retry policy ` // should not be set if this field is used. google.protobuf.Any retry_policy_typed_config = 33; @@ -1292,7 +1459,9 @@ message RouteAction { // :ref:`rate_limits ` are not applied to the // request. // - // This field is deprecated. Please use :ref:`vh_rate_limits ` + // .. attention:: + // + // This field is deprecated. Please use :ref:`vh_rate_limits ` google.protobuf.BoolValue include_vh_rate_limits = 14 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; @@ -1386,7 +1555,7 @@ message RouteAction { // Indicates that the route has a hedge policy. 
Note that if this is set, // it'll take precedence over the virtual host level hedge policy entirely - // (e.g.: policies are not merged, most internal one becomes the enforced policy). + // (e.g., policies are not merged, the most internal one becomes the enforced policy). HedgePolicy hedge_policy = 27; // Specifies the maximum stream duration for this route. @@ -1520,7 +1689,9 @@ message RetryPolicy { // Specifies the maximum back off interval that Envoy will allow. If a reset // header contains an interval longer than this then it will be discarded and - // the next header will be tried. Defaults to 300 seconds. + // the next header will be tried. + // + // Defaults to 300 seconds. google.protobuf.Duration max_interval = 2 [(validate.rules).duration = {gt {}}]; } @@ -1549,7 +1720,7 @@ message RetryPolicy { google.protobuf.Duration per_try_timeout = 3; // Specifies an upstream idle timeout per retry attempt (including the initial attempt). This - // parameter is optional and if absent there is no per try idle timeout. The semantics of the per + // parameter is optional and if absent there is no per-try idle timeout. The semantics of the per- // try idle timeout are similar to the // :ref:`route idle timeout ` and // :ref:`stream idle timeout @@ -1624,12 +1795,14 @@ message HedgePolicy { // Specifies the number of initial requests that should be sent upstream. // Must be at least 1. + // // Defaults to 1. // [#not-implemented-hide:] google.protobuf.UInt32Value initial_requests = 1 [(validate.rules).uint32 = {gte: 1}]; // Specifies a probability that an additional upstream request should be sent // on top of what is specified by initial_requests. + // // Defaults to 0. // [#not-implemented-hide:] type.v3.FractionalPercent additional_request_chance = 2; @@ -1639,14 +1812,16 @@ message HedgePolicy { // The first request to complete successfully will be the one returned to the caller. // // * At any time, a successful response (i.e. 
not triggering any of the retry-on conditions) would be returned to the client. - // * Before per-try timeout, an error response (per retry-on conditions) would be retried immediately or returned ot the client + // * Before per-try timeout, an error response (per retry-on conditions) would be retried immediately or returned to the client // if there are no more retries left. // * After per-try timeout, an error response would be discarded, as a retry in the form of a hedged request is already in progress. // - // Note: For this to have effect, you must have a :ref:`RetryPolicy ` that retries at least - // one error code and specifies a maximum number of retries. + // .. note:: + // + // For this to have effect, you must have a :ref:`RetryPolicy ` that retries at least + // one error code and specifies a maximum number of retries. // - // Defaults to false. + // Defaults to ``false``. bool hedge_on_per_try_timeout = 3; } @@ -1773,6 +1948,12 @@ message DirectResponseAction { // :ref:`envoy_v3_api_msg_config.route.v3.Route`, :ref:`envoy_v3_api_msg_config.route.v3.RouteConfiguration` or // :ref:`envoy_v3_api_msg_config.route.v3.VirtualHost`. core.v3.DataSource body = 2; + + // Specifies a format string for the response body. If present, the contents of + // ``body_format`` will be formatted and used as the response body, where the + // contents of ``body`` (may be empty) will be passed as the variable ``%LOCAL_REPLY_BODY%``. + // If neither are provided, no body is included in the generated response. + core.v3.SubstitutionFormatString body_format = 3; } // [#not-implemented-hide:] @@ -1792,10 +1973,11 @@ message Decorator { // ` header. string operation = 1 [(validate.rules).string = {min_len: 1}]; - // Whether the decorated details should be propagated to the other party. The default is true. + // Whether the decorated details should be propagated to the other party. The default is ``true``. 
google.protobuf.BoolValue propagate = 2; } +// [#next-free-field: 7] message Tracing { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.Tracing"; @@ -1831,6 +2013,34 @@ message Tracing { // each in the HTTP connection manager and the route level, the one configured here takes // priority. repeated type.tracing.v3.CustomTag custom_tags = 4; + + // The operation name of the span which will be used for tracing. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + // + // This field will take precedence over and make following settings ineffective: + // + // * :ref:`route decorator `. + // * :ref:`x-envoy-decorator-operation `. + // * :ref:`HCM tracing operation + // `. + string operation = 5; + + // The operation name of the upstream span which will be used for tracing. + // This only takes effect when ``spawn_upstream_span`` is set to true and the upstream + // span is created. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + // + // This field will take precedence over and make following settings ineffective: + // + // * :ref:`HCM tracing upstream operation + // ` + string upstream_operation = 6; } // A virtual cluster is a way of specifying a regex matching rule against @@ -1870,10 +2080,11 @@ message VirtualCluster { // Global rate limiting :ref:`architecture overview `. // Also applies to Local rate limiting :ref:`using descriptors `. 
+// [#next-free-field: 7] message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit"; - // [#next-free-field: 12] + // [#next-free-field: 13] message Action { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit.Action"; @@ -1930,9 +2141,48 @@ message RateLimit { // The key to use in the descriptor entry. string descriptor_key = 2 [(validate.rules).string = {min_len: 1}]; - // If set to true, Envoy skips the descriptor while calling rate limiting service - // when header is not present in the request. By default it skips calling the - // rate limiting service if this header is not present in the request. + // Controls the behavior when the specified header is not present in the request. + // + // If set to ``false`` (default): + // + // * Envoy does **NOT** call the rate limiting service for this descriptor. + // * Useful if the header is optional and you prefer to skip rate limiting when it's absent. + // + // If set to ``true``: + // + // * Envoy calls the rate limiting service but omits this descriptor if the header is missing. + // * Useful if you want Envoy to enforce rate limiting even when the header is not present. + // + bool skip_if_absent = 3; + } + + // The following descriptor entry is appended when a query parameter contains a key that matches the + // ``query_parameter_name``: + // + // .. code-block:: cpp + // + // ("", "") + message QueryParameters { + // The name of the query parameter to use for rate limiting. Value of this query parameter is used to populate + // the value of the descriptor entry for the descriptor_key. + string query_parameter_name = 1 [(validate.rules).string = {min_len: 1}]; + + // The key to use when creating the rate limit descriptor entry. This descriptor key will be used to identify the + // rate limit rule in the rate limiting service. 
+ string descriptor_key = 2 [(validate.rules).string = {min_len: 1}]; + + // Controls the behavior when the specified query parameter is not present in the request. + // + // If set to ``false`` (default): + // + // * Envoy does **NOT** call the rate limiting service for this descriptor. + // * Useful if the query parameter is optional and you prefer to skip rate limiting when it's absent. + // + // If set to ``true``: + // + // * Envoy calls the rate limiting service but omits this descriptor if the query parameter is missing. + // * Useful if you want Envoy to enforce rate limiting even when the query parameter is not present. + // bool skip_if_absent = 3; } @@ -1955,14 +2205,18 @@ message RateLimit { // ("masked_remote_address", "") message MaskedRemoteAddress { // Length of prefix mask len for IPv4 (e.g. 0, 32). + // // Defaults to 32 when unset. + // // For example, trusted address from x-forwarded-for is ``192.168.1.1``, // the descriptor entry is ("masked_remote_address", "192.168.1.1/32"); // if mask len is 24, the descriptor entry is ("masked_remote_address", "192.168.1.0/24"). google.protobuf.UInt32Value v4_prefix_mask_len = 1 [(validate.rules).uint32 = {lte: 32}]; // Length of prefix mask len for IPv6 (e.g. 0, 128). + // // Defaults to 128 when unset. + // // For example, trusted address from x-forwarded-for is ``2001:abcd:ef01:2345:6789:abcd:ef01:234``, // the descriptor entry is ("masked_remote_address", "2001:abcd:ef01:2345:6789:abcd:ef01:234/128"); // if mask len is 64, the descriptor entry is ("masked_remote_address", "2001:abcd:ef01:2345::/64"). @@ -1978,9 +2232,40 @@ message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit.Action.GenericKey"; - // The value to use in the descriptor entry. + // Descriptor value of entry. 
+ // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + // + // .. note:: + // + // Formatter parsing is controlled by the runtime feature flag + // ``envoy.reloadable_features.enable_formatter_for_ratelimit_action_descriptor_value`` + // (disabled by default). + // + // When enabled: The format string can contain multiple valid substitution + // fields. If multiple substitution fields are present, their results will be concatenated + // to form the final descriptor value. If it contains no substitution fields, the value + // will be used as is. If the final concatenated result is empty and ``default_value`` is set, + // the ``default_value`` will be used. If ``default_value`` is not set and the result is + // empty, this descriptor will be skipped and not included in the rate limit call. + // + // When disabled (default): The descriptor_value is used as a literal string without any formatter + // parsing or substitution. + // + // For example, ``static_value`` will be used as is since there are no substitution fields. + // ``%REQ(:method)%`` will be replaced with the HTTP method, and + // ``%REQ(:method)%%REQ(:path)%`` will be replaced with the concatenation of the HTTP method and path. + // ``%CEL(request.headers['user-id'])%`` will use CEL to extract the user ID from request headers. + // string descriptor_value = 1 [(validate.rules).string = {min_len: 1}]; + // An optional value to use if the final concatenated ``descriptor_value`` result is empty. + // Only applicable when formatter parsing is enabled by the runtime feature flag + // ``envoy.reloadable_features.enable_formatter_for_ratelimit_action_descriptor_value`` (disabled by default). + string default_value = 3; + // An optional key to use in the descriptor entry. If not set it defaults // to 'generic_key' as the descriptor key. 
string descriptor_key = 2; @@ -1991,16 +2276,51 @@ message RateLimit { // .. code-block:: cpp // // ("header_match", "") + // [#next-free-field: 6] message HeaderValueMatch { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit.Action.HeaderValueMatch"; - // The key to use in the descriptor entry. Defaults to ``header_match``. - string descriptor_key = 4; - - // The value to use in the descriptor entry. + // Descriptor value of entry. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + // + // .. note:: + // + // Formatter parsing is controlled by the runtime feature flag + // ``envoy.reloadable_features.enable_formatter_for_ratelimit_action_descriptor_value`` + // (disabled by default). + // + // When enabled: The format string can contain multiple valid substitution + // fields. If multiple substitution fields are present, their results will be concatenated + // to form the final descriptor value. If it contains no substitution fields, the value + // will be used as is. All substitution fields will be evaluated and their results + // concatenated. If the final concatenated result is empty and ``default_value`` is set, + // the ``default_value`` will be used. If ``default_value`` is not set and the result is + // empty, this descriptor will be skipped and not included in the rate limit call. + // + // When disabled (default): The descriptor_value is used as a literal string without any formatter + // parsing or substitution. + // + // For example, ``static_value`` will be used as is since there are no substitution fields. + // ``%REQ(:method)%`` will be replaced with the HTTP method, and + // ``%REQ(:method)%%REQ(:path)%`` will be replaced with the concatenation of the HTTP method and path. 
+ // ``%CEL(request.headers['user-id'])%`` will use CEL to extract the user ID from request headers. + // string descriptor_value = 1 [(validate.rules).string = {min_len: 1}]; + // An optional value to use if the final concatenated ``descriptor_value`` result is empty. + // Only applicable when formatter parsing is enabled by the runtime feature flag + // ``envoy.reloadable_features.enable_formatter_for_ratelimit_action_descriptor_value`` (disabled by default). + string default_value = 5; + + // The key to use in the descriptor entry. + // + // Defaults to ``header_match``. + string descriptor_key = 4; + // If set to true, the action will append a descriptor entry when the // request matches the headers. If set to false, the action will append a // descriptor entry when the request does not match the headers. The @@ -2008,7 +2328,7 @@ message RateLimit { google.protobuf.BoolValue expect_match = 2; // Specifies a set of headers that the rate limit action should match - // on. The action will check the request’s headers against all the + // on. The action will check the request's headers against all the // specified headers in the config. A match will happen if all the // headers in the config are present in the request with the same values // (or based on presence if the value field is not in the config). @@ -2067,9 +2387,19 @@ message RateLimit { // Source of metadata Source source = 4 [(validate.rules).enum = {defined_only: true}]; - // If set to true, Envoy skips the descriptor while calling rate limiting service - // when ``metadata_key`` is empty and ``default_value`` is not set. By default it skips calling the - // rate limiting service in that case. + // Controls the behavior when the specified ``metadata_key`` is empty and ``default_value`` is not set. + // + // If set to ``false`` (default): + // + // * Envoy does **NOT** call the rate limiting service for this descriptor. 
+ // * Useful if the metadata is optional and you prefer to skip rate limiting when it's absent. + // + // If set to ``true``: + // + // * Envoy calls the rate limiting service but omits this descriptor if the ``metadata_key`` is empty and + // ``default_value`` is missing. + // * Useful if you want Envoy to enforce rate limiting even when the metadata is not present. + // bool skip_if_absent = 5; } @@ -2078,13 +2408,48 @@ message RateLimit { // .. code-block:: cpp // // ("query_match", "") + // [#next-free-field: 6] message QueryParameterValueMatch { - // The key to use in the descriptor entry. Defaults to ``query_match``. - string descriptor_key = 4; - - // The value to use in the descriptor entry. + // Descriptor value of entry. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + // + // .. note:: + // + // Formatter parsing is controlled by the runtime feature flag + // ``envoy.reloadable_features.enable_formatter_for_ratelimit_action_descriptor_value`` + // (disabled by default). + // + // When enabled: The format string can contain multiple valid substitution + // fields. If multiple substitution fields are present, their results will be concatenated + // to form the final descriptor value. If it contains no substitution fields, the value + // will be used as is. All substitution fields will be evaluated and their results + // concatenated. If the final concatenated result is empty and ``default_value`` is set, + // the ``default_value`` will be used. If ``default_value`` is not set and the result is + // empty, this descriptor will be skipped and not included in the rate limit call. + // + // When disabled (default): The descriptor_value is used as a literal string without any formatter + // parsing or substitution. + // + // For example, ``static_value`` will be used as is since there are no substitution fields. 
+ // ``%REQ(:method)%`` will be replaced with the HTTP method, and + // ``%REQ(:method)%%REQ(:path)%`` will be replaced with the concatenation of the HTTP method and path. + // ``%CEL(request.headers['user-id'])%`` will use CEL to extract the user ID from request headers. + // string descriptor_value = 1 [(validate.rules).string = {min_len: 1}]; + // An optional value to use if the final concatenated ``descriptor_value`` result is empty. + // Only applicable when formatter parsing is enabled by the runtime feature flag + // ``envoy.reloadable_features.enable_formatter_for_ratelimit_action_descriptor_value`` (disabled by default). + string default_value = 5; + + // The key to use in the descriptor entry. + // + // Defaults to ``query_match``. + string descriptor_key = 4; + // If set to true, the action will append a descriptor entry when the // request matches the headers. If set to false, the action will append a // descriptor entry when the request does not match the headers. The @@ -2092,7 +2457,7 @@ message RateLimit { google.protobuf.BoolValue expect_match = 2; // Specifies a set of query parameters that the rate limit action should match - // on. The action will check the request’s query parameters against all the + // on. The action will check the request's query parameters against all the // specified query parameters in the config. A match will happen if all the // query parameters in the config are present in the request with the same values // (or based on presence if the value field is not in the config). @@ -2112,6 +2477,9 @@ message RateLimit { // Rate limit on request headers. RequestHeaders request_headers = 3; + // Rate limit on query parameters. + QueryParameters query_parameters = 12; + // Rate limit on remote address. RemoteAddress remote_address = 4; @@ -2170,6 +2538,33 @@ message RateLimit { } } + message HitsAddend { + // Fixed number of hits to add to the rate limit descriptor. 
+ // + // One of the ``number`` or ``format`` fields should be set but not both. + google.protobuf.UInt64Value number = 1 [(validate.rules).uint64 = {lte: 1000000000}]; + + // Substitution format string to extract the number of hits to add to the rate limit descriptor. + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here. + // + // .. note:: + // + // The format string must contain only a single valid substitution field. If the format string + // does not meet the requirement, the configuration will be rejected. + // + // The substitution field should generate a non-negative number or string representation of + // a non-negative number. The value of the non-negative number should be less than or equal + // to 1000000000 like the ``number`` field. If the output of the substitution field does not meet + // the requirement, this will be treated as an error and the current descriptor will be ignored. + // + // For example, the ``%BYTES_RECEIVED%`` format string will be replaced with the number of bytes + // received in the request. + // + // One of the ``number`` or ``format`` fields should be set but not both. + string format = 2 [(validate.rules).string = {prefix: "%" suffix: "%" ignore_empty: true}]; + } + + // Refers to the stage set in the filter. The rate limit configuration only // applies to filters with the same stage number. The default stage number is // 0. @@ -2177,9 +2572,19 @@ message RateLimit { // .. note:: // // The filter supports a range of 0 - 10 inclusively for stage numbers. + // + // .. note:: + // This is not supported if the rate limit action is configured in the ``typed_per_filter_config`` like + // :ref:`VirtualHost.typed_per_filter_config` or + // :ref:`Route.typed_per_filter_config`, etc. google.protobuf.UInt32Value stage = 1 [(validate.rules).uint32 = {lte: 10}]; // The key to be set in runtime to disable this rate limit configuration. + // + // .. 
note:: + // This is not supported if the rate limit action is configured in the ``typed_per_filter_config`` like + // :ref:`VirtualHost.typed_per_filter_config` or + // :ref:`Route.typed_per_filter_config`, etc. string disable_key = 2; // A list of actions that are to be applied for this rate limit configuration. @@ -2194,7 +2599,38 @@ message RateLimit { // rate limit configuration. If the override value is invalid or cannot be resolved // from metadata, no override is provided. See :ref:`rate limit override // ` for more information. + // + // .. note:: + // This is not supported if the rate limit action is configured in the ``typed_per_filter_config`` like + // :ref:`VirtualHost.typed_per_filter_config` or + // :ref:`Route.typed_per_filter_config`, etc. Override limit = 4; + + // An optional hits addend to be appended to the descriptor produced by this rate limit + // configuration. + // + // .. note:: + // This is only supported if the rate limit action is configured in the ``typed_per_filter_config`` like + // :ref:`VirtualHost.typed_per_filter_config` or + // :ref:`Route.typed_per_filter_config`, etc. + HitsAddend hits_addend = 5; + + // If true, the rate limit request will be applied when the stream completes. The default value is false. + // This is useful when the rate limit budget needs to reflect the response context that is not available + // on the request path. + // + // For example, let's say the upstream service calculates the usage statistics and returns them in the response body + // and we want to utilize these numbers to apply the rate limit action for the subsequent requests. + // Combined with another filter that can set the desired addend based on the response (e.g. Lua filter), + // this can be used to subtract the usage statistics from the rate limit budget. + // + // A rate limit applied on the stream completion is "fire-and-forget" by nature, and rate limit is not enforced by this config. 
+ // In other words, the current request won't be blocked when this is true, but the budget will be updated for the subsequent + // requests based on the action with this field set to true. Users should ensure that the rate limit is enforced by the actions + // applied on the request path, i.e. the ones with this field set to false. + // + // Currently, this is only supported by the HTTP global rate filter. + bool apply_on_stream_done = 6; } // .. attention:: @@ -2238,14 +2674,20 @@ message HeaderMatcher { // Specifies how the header match will be performed to route the request. oneof header_match_specifier { // If specified, header match will be performed based on the value of the header. - // This field is deprecated. Please use :ref:`string_match `. + // + // .. attention:: + // + // This field is deprecated. Please use :ref:`string_match `. string exact_match = 4 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // If specified, this regex string is a regular expression rule which implies the entire request // header value must match the regex. The rule will not match if only a subsequence of the // request header value matches the regex. - // This field is deprecated. Please use :ref:`string_match `. + // + // .. attention:: + // + // This field is deprecated. Please use :ref:`string_match `. type.matcher.v3.RegexMatcher safe_regex_match = 11 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; @@ -2267,8 +2709,14 @@ message HeaderMatcher { bool present_match = 7; // If specified, header match will be performed based on the prefix of the header value. - // Note: empty prefix is not allowed, please use present_match instead. - // This field is deprecated. Please use :ref:`string_match `. + // + // .. note:: + // + // Empty prefix is not allowed. Please use ``present_match`` instead. + // + // .. attention:: + // + // This field is deprecated. Please use :ref:`string_match `. 
// // Examples: // @@ -2280,8 +2728,14 @@ message HeaderMatcher { ]; // If specified, header match will be performed based on the suffix of the header value. - // Note: empty suffix is not allowed, please use present_match instead. - // This field is deprecated. Please use :ref:`string_match `. + // + // .. note:: + // + // Empty suffix is not allowed. Please use ``present_match`` instead. + // + // .. attention:: + // + // This field is deprecated. Please use :ref:`string_match `. // // Examples: // @@ -2294,8 +2748,14 @@ message HeaderMatcher { // If specified, header match will be performed based on whether the header value contains // the given value or not. - // Note: empty contains match is not allowed, please use present_match instead. - // This field is deprecated. Please use :ref:`string_match `. + // + // .. note:: + // + // Empty contains match is not allowed. Please use ``present_match`` instead. + // + // .. attention:: + // + // This field is deprecated. Please use :ref:`string_match `. // // Examples: // @@ -2310,7 +2770,9 @@ message HeaderMatcher { type.matcher.v3.StringMatcher string_match = 13; } - // If specified, the match result will be inverted before checking. Defaults to false. + // If specified, the match result will be inverted before checking. + // + // Defaults to ``false``. // // Examples: // @@ -2319,7 +2781,9 @@ message HeaderMatcher { bool invert_match = 8; // If specified, for any header match rule, if the header match rule specified header - // does not exist, this header value will be treated as empty. Defaults to false. + // does not exist, this header value will be treated as empty. + // + // Defaults to ``false``. // // Examples: // @@ -2371,6 +2835,20 @@ message QueryParameterMatcher { } } +// Cookie matching inspects individual name/value pairs parsed from the ``Cookie`` header. +message CookieMatcher { + // Specifies the cookie name to evaluate. 
+ string name = 1 [(validate.rules).string = {min_len: 1 max_bytes: 1024}]; + + // Match the cookie value using :ref:`StringMatcher + // ` semantics. + type.matcher.v3.StringMatcher string_match = 2 [(validate.rules).message = {required: true}]; + + // Invert the match result. If the cookie is not present, the match result is false, so + // ``invert_match`` will cause the matcher to succeed when the cookie is absent. + bool invert_match = 3; +} + // HTTP Internal Redirect :ref:`architecture overview `. // [#next-free-field: 6] message InternalRedirectPolicy { @@ -2396,7 +2874,7 @@ message InternalRedirectPolicy { repeated core.v3.TypedExtensionConfig predicates = 3; // Allow internal redirect to follow a target URI with a different scheme than the value of - // x-forwarded-proto. The default is false. + // x-forwarded-proto. The default is ``false``. bool allow_cross_scheme_redirect = 4; // Specifies a list of headers, by name, to copy from the internal redirect into the subsequent @@ -2436,6 +2914,5 @@ message FilterConfig { // initial route will not be added back to the filter chain because the filter chain is already // created and it is too late to change the chain. // - // This field only make sense for the downstream HTTP filters for now. 
bool disabled = 3; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/datadog.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/datadog.proto index bed6c8eec36..5359ec74267 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/datadog.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/datadog.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package envoy.config.trace.v3; +import "google/protobuf/duration.proto"; + import "udpa/annotations/migrate.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; @@ -16,6 +18,13 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Datadog tracer] +// Configuration for the Remote Configuration feature. +message DatadogRemoteConfig { + // Frequency at which new configuration updates are queried. + // If no value is provided, the default value is delegated to the Datadog tracing library. + google.protobuf.Duration polling_interval = 1; +} + // Configuration for the Datadog tracer. // [#extension: envoy.tracers.datadog] message DatadogConfig { @@ -31,4 +40,11 @@ message DatadogConfig { // Optional hostname to use when sending spans to the collector_cluster. Useful for collectors // that require a specific hostname. Defaults to :ref:`collector_cluster ` above. string collector_hostname = 3; + + // Enables and configures remote configuration. + // Remote Configuration allows to configure the tracer from Datadog's user interface. + // This feature can drastically increase the number of connections to the Datadog Agent. + // Each tracer regularly polls for configuration updates, and the number of tracers is the product + // of the number of listeners and worker threads. 
+ DatadogRemoteConfig remote_config = 4; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/dynamic_ot.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/dynamic_ot.proto index d2664ef717e..40fe8526a5f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/dynamic_ot.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/dynamic_ot.proto @@ -20,10 +20,10 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Dynamically loadable OpenTracing tracer] -// DynamicOtConfig is used to dynamically load a tracer from a shared library +// DynamicOtConfig was used to dynamically load a tracer from a shared library // that implements the `OpenTracing dynamic loading API // `_. -// [#extension: envoy.tracers.dynamic_ot] +// [#not-implemented-hide:] message DynamicOtConfig { option (udpa.annotations.versioning).previous_message_type = "envoy.config.trace.v2.DynamicOtConfig"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opentelemetry.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opentelemetry.proto index 59028326f22..5260d9bd6af 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opentelemetry.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opentelemetry.proto @@ -6,6 +6,8 @@ import "envoy/config/core/v3/extension.proto"; import "envoy/config/core/v3/grpc_service.proto"; import "envoy/config/core/v3/http_service.proto"; +import "google/protobuf/wrappers.proto"; + import "udpa/annotations/migrate.proto"; import "udpa/annotations/status.proto"; @@ -19,7 +21,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Configuration for the OpenTelemetry tracer. // [#extension: envoy.tracers.opentelemetry] -// [#next-free-field: 6] +// [#next-free-field: 7] message OpenTelemetryConfig { // The upstream gRPC cluster that will receive OTLP traces. 
// Note that the tracer drops traces if the server does not read data fast enough. @@ -57,4 +59,9 @@ message OpenTelemetryConfig { // See: `OpenTelemetry sampler specification `_ // [#extension-category: envoy.tracers.opentelemetry.samplers] core.v3.TypedExtensionConfig sampler = 5; + + // Envoy caches the span in memory when the OpenTelemetry backend service is temporarily unavailable. + // This field specifies the maximum number of spans that can be cached. If not specified, the + // default is 1024. + google.protobuf.UInt32Value max_cache_size = 6; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/zipkin.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/zipkin.proto index 2d8f3195c31..2364983efc5 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/zipkin.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/zipkin.proto @@ -2,13 +2,14 @@ syntax = "proto3"; package envoy.config.trace.v3; +import "envoy/config/core/v3/http_service.proto"; + import "google/protobuf/wrappers.proto"; import "envoy/annotations/deprecation.proto"; import "udpa/annotations/migrate.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; -import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v3"; option java_outer_classname = "ZipkinProto"; @@ -21,10 +22,22 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Configuration for the Zipkin tracer. // [#extension: envoy.tracers.zipkin] -// [#next-free-field: 8] +// [#next-free-field: 10] message ZipkinConfig { option (udpa.annotations.versioning).previous_message_type = "envoy.config.trace.v2.ZipkinConfig"; + // Available trace context options for handling different trace header formats. + enum TraceContextOption { + // Use B3 headers only (default behavior). 
+ USE_B3 = 0; + + // Enable B3 and W3C dual header support: + // - For downstream: Extract from B3 headers first, fallback to W3C traceparent if B3 is unavailable. + // - For upstream: Inject both B3 and W3C traceparent headers. + // When this option is NOT set, only B3 headers are used for both extraction and injection. + USE_B3_WITH_W3C_PROPAGATION = 1; + } + // Available Zipkin collector endpoint versions. enum CollectorEndpointVersion { // Zipkin API v1, JSON over HTTP. @@ -48,11 +61,23 @@ message ZipkinConfig { } // The cluster manager cluster that hosts the Zipkin collectors. - string collector_cluster = 1 [(validate.rules).string = {min_len: 1}]; + // + // .. note:: + // This field will be deprecated in future releases in favor of + // :ref:`collector_service `. + // + // Either this field or ``collector_service`` must be specified. + string collector_cluster = 1; // The API endpoint of the Zipkin service where the spans will be sent. When // using a standard Zipkin installation. - string collector_endpoint = 2 [(validate.rules).string = {min_len: 1}]; + // + // .. note:: + // This field will be deprecated in future releases in favor of + // :ref:`collector_service `. + // + // Required when using ``collector_cluster``. + string collector_endpoint = 2; // Determines whether a 128bit trace id will be used when creating a new // trace instance. The default value is false, which will result in a 64 bit trace id being used. @@ -67,6 +92,10 @@ message ZipkinConfig { // Optional hostname to use when sending spans to the collector_cluster. Useful for collectors // that require a specific hostname. Defaults to :ref:`collector_cluster ` above. + // + // .. note:: + // This field will be deprecated in future releases in favor of + // :ref:`collector_service `. string collector_hostname = 6; // If this is set to true, then Envoy will be treated as an independent hop in trace chain. 
A complete span pair will be created for a single @@ -88,4 +117,60 @@ message ZipkinConfig { // Please use that ``spawn_upstream_span`` field to control the span creation. bool split_spans_for_request = 7 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // Determines which trace context format to use for trace header extraction and propagation. + // This controls both downstream request header extraction and upstream request header injection. + // Here is the spec for W3C trace headers: https://www.w3.org/TR/trace-context/ + // The default value is USE_B3 to maintain backward compatibility. + TraceContextOption trace_context_option = 8; + + // HTTP service configuration for the Zipkin collector. + // When specified, this configuration takes precedence over the legacy fields: + // collector_cluster, collector_endpoint, and collector_hostname. + // This provides a complete HTTP service configuration including cluster, URI, timeout, and headers. + // If not specified, the legacy fields above will be used for backward compatibility. + // + // Required fields when using collector_service: + // + // * ``http_uri.cluster`` - Must be specified and non-empty + // * ``http_uri.uri`` - Must be specified and non-empty + // * ``http_uri.timeout`` - Optional + // + // Full URI Support with Automatic Parsing: + // + // The ``uri`` field supports both path-only and full URI formats: + // + // .. 
code-block:: yaml + // + // tracing: + // provider: + // name: envoy.tracers.zipkin + // typed_config: + // "@type": type.googleapis.com/envoy.config.trace.v3.ZipkinConfig + // collector_service: + // http_uri: + // # Full URI format - hostname and path are extracted automatically + // uri: "https://zipkin-collector.example.com/api/v2/spans" + // cluster: zipkin + // timeout: 5s + // request_headers_to_add: + // - header: + // key: "X-Custom-Token" + // value: "your-custom-token" + // - header: + // key: "X-Service-ID" + // value: "your-service-id" + // + // URI Parsing Behavior: + // + // * Full URI: ``"https://zipkin-collector.example.com/api/v2/spans"`` + // + // * Hostname: ``zipkin-collector.example.com`` (sets HTTP ``Host`` header) + // * Path: ``/api/v2/spans`` (sets HTTP request path) + // + // * Path only: ``"/api/v2/spans"`` + // + // * Hostname: Uses cluster name as fallback + // * Path: ``/api/v2/spans`` + core.v3.HttpService collector_service = 9; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/data/accesslog/v3/accesslog.proto b/xds/third_party/envoy/src/main/proto/envoy/data/accesslog/v3/accesslog.proto index 2e02f1eb455..da029b7da2e 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/data/accesslog/v3/accesslog.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/data/accesslog/v3/accesslog.proto @@ -109,14 +109,16 @@ message AccessLogCommon { double sample_rate = 1 [(validate.rules).double = {lte: 1.0 gt: 0.0}]; // This field is the remote/origin address on which the request from the user was received. - // Note: This may not be the physical peer. E.g, if the remote address is inferred from for - // example the x-forwarder-for header, proxy protocol, etc. + // + // .. note:: + // This may not be the actual peer address. For example, it might be derived from headers like ``x-forwarded-for``, + // the proxy protocol, or similar sources. 
config.core.v3.Address downstream_remote_address = 2; // This field is the local/destination address on which the request from the user was received. config.core.v3.Address downstream_local_address = 3; - // If the connection is secure,S this field will contain TLS properties. + // If the connection is secure, this field will contain TLS properties. TLSProperties tls_properties = 4; // The time that Envoy started servicing this request. This is effectively the time that the first @@ -128,7 +130,7 @@ message AccessLogCommon { google.protobuf.Duration time_to_last_rx_byte = 6; // Interval between the first downstream byte received and the first upstream byte sent. There may - // by considerable delta between ``time_to_last_rx_byte`` and this value due to filters. + // be considerable delta between ``time_to_last_rx_byte`` and this value due to filters. // Additionally, the same caveats apply as documented in ``time_to_last_downstream_tx_byte`` about // not accounting for kernel socket buffer time, etc. google.protobuf.Duration time_to_first_upstream_tx_byte = 7; @@ -187,7 +189,7 @@ message AccessLogCommon { // If upstream connection failed due to transport socket (e.g. TLS handshake), provides the // failure reason from the transport socket. The format of this field depends on the configured // upstream transport socket. Common TLS failures are in - // :ref:`TLS trouble shooting `. + // :ref:`TLS troubleshooting `. string upstream_transport_failure_reason = 18; // The name of the route @@ -204,7 +206,7 @@ message AccessLogCommon { map filter_state_objects = 21; // A list of custom tags, which annotate logs with additional information. - // To configure this value, users should configure + // To configure this value, see the documentation for // :ref:`custom_tags `. map custom_tags = 22; @@ -225,40 +227,41 @@ message AccessLogCommon { // This could be any format string that could be used to identify one stream. 
string stream_id = 26; - // If this log entry is final log entry that flushed after the stream completed or - // intermediate log entry that flushed periodically during the stream. - // There may be multiple intermediate log entries and only one final log entry for each - // long-live stream (TCP connection, long-live HTTP2 stream). - // And if it is necessary, unique ID or identifier can be added to the log entry - // :ref:`stream_id ` to - // correlate all these intermediate log entries and final log entry. + // Indicates whether this log entry is the final entry (flushed after the stream completed) or an intermediate entry + // (flushed periodically during the stream). + // + // For long-lived streams (e.g., TCP connections or long-lived HTTP/2 streams), there may be multiple intermediate + // entries and only one final entry. + // + // If needed, a unique identifier (see :ref:`stream_id `) + // can be used to correlate all intermediate and final log entries for the same stream. // // .. attention:: // - // This field is deprecated in favor of ``access_log_type`` for better indication of the - // type of the access log record. + // This field is deprecated in favor of ``access_log_type``, which provides a clearer indication of the log entry + // type. bool intermediate_log_entry = 27 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // If downstream connection in listener failed due to transport socket (e.g. TLS handshake), provides the // failure reason from the transport socket. The format of this field depends on the configured downstream - // transport socket. Common TLS failures are in :ref:`TLS trouble shooting `. + // transport socket. Common TLS failures are in :ref:`TLS troubleshooting `. string downstream_transport_failure_reason = 28; // For HTTP: Total number of bytes sent to the downstream by the http stream. - // For TCP: Total number of bytes sent to the downstream by the tcp proxy. 
+ // For TCP: Total number of bytes sent to the downstream by the :ref:`TCP Proxy `. uint64 downstream_wire_bytes_sent = 29; // For HTTP: Total number of bytes received from the downstream by the http stream. Envoy over counts sizes of received HTTP/1.1 pipelined requests by adding up bytes of requests in the pipeline to the one currently being processed. - // For TCP: Total number of bytes received from the downstream by the tcp proxy. + // For TCP: Total number of bytes received from the downstream by the :ref:`TCP Proxy `. uint64 downstream_wire_bytes_received = 30; // For HTTP: Total number of bytes sent to the upstream by the http stream. This value accumulates during upstream retries. - // For TCP: Total number of bytes sent to the upstream by the tcp proxy. + // For TCP: Total number of bytes sent to the upstream by the :ref:`TCP Proxy `. uint64 upstream_wire_bytes_sent = 31; // For HTTP: Total number of bytes received from the upstream by the http stream. - // For TCP: Total number of bytes sent to the upstream by the tcp proxy. + // For TCP: Total number of bytes sent to the upstream by the :ref:`TCP Proxy `. uint64 upstream_wire_bytes_received = 32; // The type of the access log, which indicates when the log was recorded. @@ -297,7 +300,7 @@ message ResponseFlags { // Indicates there was no healthy upstream. bool no_healthy_upstream = 2; - // Indicates an there was an upstream request timeout. + // Indicates there was an upstream request timeout. bool upstream_request_timeout = 3; // Indicates local codec level reset was sent on the stream. @@ -358,7 +361,7 @@ message ResponseFlags { // Indicates that a filter configuration is not available. bool no_filter_config_found = 22; - // Indicates that request or connection exceeded the downstream connection duration. + // Indicates that the request or connection exceeded the downstream connection duration. bool duration_timeout = 23; // Indicates there was an HTTP protocol error in the upstream response. 
@@ -480,7 +483,7 @@ message HTTPRequestProperties { // do not already have a request ID. string request_id = 9; - // Value of the ``X-Envoy-Original-Path`` request header. + // Value of the ``x-envoy-original-path`` request header. string original_path = 10; // Size of the HTTP request headers in bytes. diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/clusters/aggregate/v3/cluster.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/clusters/aggregate/v3/cluster.proto index 4f44ac9cd5c..d23d767f73b 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/clusters/aggregate/v3/cluster.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/clusters/aggregate/v3/cluster.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package envoy.extensions.clusters.aggregate.v3; +import "envoy/config/core/v3/config_source.proto"; + import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -25,3 +27,18 @@ message ClusterConfig { // appear in this list. repeated string clusters = 1 [(validate.rules).repeated = {min_items: 1}]; } + +// Configures an aggregate cluster whose +// :ref:`ClusterConfig ` +// is to be fetched from a separate xDS resource. +// [#extension: envoy.clusters.aggregate_resource] +// [#not-implemented-hide:] +message AggregateClusterResource { + // Configuration source specifier for the ClusterConfig resource. + // Only the aggregated protocol variants are supported; if configured + // otherwise, the cluster resource will be NACKed. + config.core.v3.ConfigSource config_source = 1 [(validate.rules).message = {required: true}]; + + // The name of the ClusterConfig resource to subscribe to. 
+ string resource_name = 2 [(validate.rules).string = {min_len: 1}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/common/matching/v3/extension_matcher.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/common/matching/v3/extension_matcher.proto new file mode 100644 index 00000000000..817cd27a37a --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/common/matching/v3/extension_matcher.proto @@ -0,0 +1,42 @@ +syntax = "proto3"; + +package envoy.extensions.common.matching.v3; + +import "envoy/config/common/matcher/v3/matcher.proto"; +import "envoy/config/core/v3/extension.proto"; + +import "xds/type/matcher/v3/matcher.proto"; + +import "envoy/annotations/deprecation.proto"; +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.common.matching.v3"; +option java_outer_classname = "ExtensionMatcherProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/common/matching/v3;matchingv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Extension matcher] + +// Wrapper around an existing extension that provides an associated matcher. This allows +// decorating an existing extension with a matcher, which can be used to match against +// relevant protocol data. +message ExtensionWithMatcher { + // The associated matcher. This is deprecated in favor of xds_matcher. + config.common.matcher.v3.Matcher matcher = 1 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // The associated matcher. + xds.type.matcher.v3.Matcher xds_matcher = 3; + + // The underlying extension config. + config.core.v3.TypedExtensionConfig extension_config = 2 + [(validate.rules).message = {required: true}]; +} + +// Extra settings on a per virtualhost/route/weighted-cluster level. 
+message ExtensionWithMatcherPerRoute { + // Matcher override. + xds.type.matcher.v3.Matcher xds_matcher = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/composite/v3/composite.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/composite/v3/composite.proto new file mode 100644 index 00000000000..1ab6c5eb1ef --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/composite/v3/composite.proto @@ -0,0 +1,106 @@ +syntax = "proto3"; + +package envoy.extensions.filters.http.composite.v3; + +import "envoy/config/core/v3/base.proto"; +import "envoy/config/core/v3/config_source.proto"; +import "envoy/config/core/v3/extension.proto"; + +import "udpa/annotations/migrate.proto"; +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.filters.http.composite.v3"; +option java_outer_classname = "CompositeProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/composite/v3;compositev3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Composite] +// Composite Filter :ref:`configuration overview `. +// [#extension: envoy.filters.http.composite] + +// :ref:`Composite filter ` config. The composite filter config +// allows delegating filter handling to another filter as determined by matching on the request +// headers. This makes it possible to use different filters or filter configurations based on the +// incoming request. +// +// This is intended to be used with +// :ref:`ExtensionWithMatcher ` +// where a match tree is specified that indicates (via +// :ref:`ExecuteFilterAction `) +// which filter configuration to create and delegate to. +message Composite { + // Named filter chain definitions that can be referenced from + // :ref:`ExecuteFilterAction.filter_chain_name + // `. 
+ // The filter chains are compiled at configuration time and can be referenced by name. + // This is useful when the same filter chain needs to be applied across many routes, + // as it avoids duplicating the filter chain configuration. + map named_filter_chains = 1; +} + +// A list of filter configurations to be called in order. Note that this can be used as the type +// inside of an ECDS :ref:`TypedExtensionConfig +// ` extension, which allows a chain of +// filters to be configured dynamically. In that case, the types of all filters in the chain must +// be present in the :ref:`ExtensionConfigSource.type_urls +// ` field. +message FilterChainConfiguration { + repeated config.core.v3.TypedExtensionConfig typed_config = 1; +} + +// Configuration for an extension configuration discovery service with name. +message DynamicConfig { + // The name of the extension configuration. It also serves as a resource name in ExtensionConfigDS. + // The resource type in the ``DiscoveryRequest`` will be :ref:`TypedExtensionConfig + // `. + string name = 1 [(validate.rules).string = {min_len: 1}]; + + // Configuration source specifier for an extension configuration discovery + // service. In case of a failure and without the default configuration, + // 500(Internal Server Error) will be returned. + config.core.v3.ExtensionConfigSource config_discovery = 2; +} + +// Composite match action (see :ref:`matching docs ` for more info on match actions). +// This specifies the filter configuration of the filter that the composite filter should delegate filter interactions to. +// [#next-free-field: 6] +message ExecuteFilterAction { + // Filter specific configuration which depends on the filter being + // instantiated. See the supported filters for further documentation. + // Only one of ``typed_config``, ``dynamic_config``, ``filter_chain``, or ``filter_chain_name`` + // can be set. 
+ // [#extension-category: envoy.filters.http] + config.core.v3.TypedExtensionConfig typed_config = 1 + [(udpa.annotations.field_migrate).oneof_promotion = "config_type"]; + + // Dynamic configuration of filter obtained via extension configuration discovery service. + // Only one of ``typed_config``, ``dynamic_config``, ``filter_chain``, or ``filter_chain_name`` + // can be set. + DynamicConfig dynamic_config = 2 + [(udpa.annotations.field_migrate).oneof_promotion = "config_type"]; + + // An inlined list of filter configurations. The specified filters will be executed in order. + // Only one of ``typed_config``, ``dynamic_config``, ``filter_chain``, or ``filter_chain_name`` + // can be set. + FilterChainConfiguration filter_chain = 4; + + // The name of a filter chain defined in + // :ref:`Composite.named_filter_chains + // `. + // At runtime, if the named filter chain is not found in the Composite filter's configuration, + // no filter will be applied for this match (the action is silently skipped). + // Only one of ``typed_config``, ``dynamic_config``, ``filter_chain``, or ``filter_chain_name`` + // can be set. + string filter_chain_name = 5; + + // Probability of the action execution. If not specified, this is 100%. + // This allows sampling behavior for the configured actions. + // For example, if + // :ref:`default_value ` + // under the ``sample_percent`` is configured with 30%, a dice roll with that + // probability is done. The underline action will only be executed if the + // dice roll returns positive. Otherwise, the action is skipped. 
+ config.core.v3.RuntimeFractionalPercent sample_percent = 3; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/ext_authz/v3/ext_authz.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/ext_authz/v3/ext_authz.proto new file mode 100644 index 00000000000..7f70b70013b --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/ext_authz/v3/ext_authz.proto @@ -0,0 +1,602 @@ +syntax = "proto3"; + +package envoy.extensions.filters.http.ext_authz.v3; + +import "envoy/config/common/mutation_rules/v3/mutation_rules.proto"; +import "envoy/config/core/v3/base.proto"; +import "envoy/config/core/v3/config_source.proto"; +import "envoy/config/core/v3/grpc_service.proto"; +import "envoy/config/core/v3/http_uri.proto"; +import "envoy/type/matcher/v3/metadata.proto"; +import "envoy/type/matcher/v3/string.proto"; +import "envoy/type/v3/http_status.proto"; + +import "google/protobuf/struct.proto"; +import "google/protobuf/wrappers.proto"; + +import "envoy/annotations/deprecation.proto"; +import "udpa/annotations/sensitive.proto"; +import "udpa/annotations/status.proto"; +import "udpa/annotations/versioning.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3"; +option java_outer_classname = "ExtAuthzProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_authz/v3;ext_authzv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: External Authorization] +// External Authorization :ref:`configuration overview `. 
+// [#extension: envoy.filters.http.ext_authz] + +// [#next-free-field: 32] +message ExtAuthz { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v3.ExtAuthz"; + + reserved 4; + + reserved "use_alpha"; + + // External authorization service configuration. + oneof services { + // gRPC service configuration (default timeout: 200ms). + config.core.v3.GrpcService grpc_service = 1; + + // HTTP service configuration (default timeout: 200ms). + HttpService http_service = 3; + } + + // API version for ext_authz transport protocol. This describes the ext_authz gRPC endpoint and + // version of messages used on the wire. + config.core.v3.ApiVersion transport_api_version = 12 + [(validate.rules).enum = {defined_only: true}]; + + // Changes the filter's behavior on errors: + // + // * When set to ``true``, the filter will ``accept`` the client request even if communication with + // the authorization service has failed, or if the authorization service has returned an HTTP 5xx + // error. + // + // * When set to ``false``, the filter will ``reject`` client requests and return ``Forbidden`` + // if communication with the authorization service has failed, or if the authorization service + // has returned an HTTP 5xx error. + // + // Errors can always be tracked in the :ref:`stats `. + // + // Defaults to ``false``. + bool failure_mode_allow = 2; + + // When ``failure_mode_allow`` and ``failure_mode_allow_header_add`` are both set to ``true``, + // ``x-envoy-auth-failure-mode-allowed: true`` will be added to request headers if the communication + // with the authorization service has failed, or if the authorization service has returned a + // HTTP 5xx error. + bool failure_mode_allow_header_add = 19; + + // Enables the filter to buffer the client request body and send it within the authorization request. 
+ // The ``x-envoy-auth-partial-body: false|true`` metadata header will be added to the authorization + // request indicating whether the body data is partial. + BufferSettings with_request_body = 5; + + // Clears the route cache in order to allow the external authorization service to correctly affect + // routing decisions. The filter clears all cached routes when all of the following holds: + // + // * This field is set to ``true``. + // * The status returned from the authorization service is an HTTP 200 or gRPC 0. + // * At least one ``authorization response header`` is added to the client request, or is used to + // alter another client request header. + // + // Defaults to ``false``. + bool clear_route_cache = 6; + + // Sets the HTTP status that is returned to the client when the authorization server returns an error + // or cannot be reached. + // + // The default status is ``HTTP 403 Forbidden``. + type.v3.HttpStatus status_on_error = 7; + + // When set to ``true``, the filter will check the :ref:`ext_authz response + // ` for invalid header and + // query parameter mutations. If the response is invalid, the filter will send a local reply + // to the downstream request with status ``HTTP 500 Internal Server Error``. + // + // .. note:: + // Both ``headers_to_remove`` and ``query_parameters_to_remove`` are validated, but invalid elements in + // those fields should not affect any headers and thus will not cause the filter to send a local reply. + // + // When set to ``false``, any invalid mutations will be visible to the rest of Envoy and may cause + // unexpected behavior. + // + // If you are using ext_authz with an untrusted ext_authz server, you should set this to ``true``. + // + // Defaults to ``false``. + bool validate_mutations = 24; + + // Specifies a list of metadata namespaces whose values, if present, will be passed to the + // ext_authz service. The :ref:`filter_metadata ` + // is passed as an opaque ``protobuf::Struct``. + // + // .. 
note:: + // This field applies exclusively to the gRPC ext_authz service and has no effect on the HTTP service. + // + // For example, if the ``jwt_authn`` filter is used and :ref:`payload_in_metadata + // ` is set, + // then the following will pass the jwt payload to the authorization server. + // + // .. code-block:: yaml + // + // metadata_context_namespaces: + // - envoy.filters.http.jwt_authn + // + repeated string metadata_context_namespaces = 8; + + // Specifies a list of metadata namespaces whose values, if present, will be passed to the + // ext_authz service. :ref:`typed_filter_metadata ` + // is passed as a ``protobuf::Any``. + // + // .. note:: + // This field applies exclusively to the gRPC ext_authz service and has no effect on the HTTP service. + // + // This works similarly to ``metadata_context_namespaces`` but allows Envoy and the ext_authz server to share + // the protobuf message definition in order to perform safe parsing. + // + repeated string typed_metadata_context_namespaces = 16; + + // Specifies a list of route metadata namespaces whose values, if present, will be passed to the + // ext_authz service at :ref:`route_metadata_context ` in + // :ref:`CheckRequest `. + // :ref:`filter_metadata ` is passed as an opaque ``protobuf::Struct``. + repeated string route_metadata_context_namespaces = 21; + + // Specifies a list of route metadata namespaces whose values, if present, will be passed to the + // ext_authz service at :ref:`route_metadata_context ` in + // :ref:`CheckRequest `. + // :ref:`typed_filter_metadata ` is passed as a ``protobuf::Any``. + repeated string route_typed_metadata_context_namespaces = 22; + + // Specifies if the filter is enabled. + // + // If :ref:`runtime_key ` is specified, + // Envoy will lookup the runtime key to get the percentage of requests to filter. + // + // If this field is not specified, the filter will be enabled for all requests. 
+ config.core.v3.RuntimeFractionalPercent filter_enabled = 9; + + // Specifies if the filter is enabled with metadata matcher. + // If this field is not specified, the filter will be enabled for all requests. + // + // .. note:: + // + // This field is only evaluated if the filter is instantiated. If the filter is marked with + // ``disabled: true`` in the :ref:`HttpFilter + // ` + // configuration or in per-route configuration via :ref:`ExtAuthzPerRoute + // `, + // the filter will not be instantiated and this field will have no effect. + // + // .. tip:: + // + // For dynamic filter activation based on metadata (such as metadata set by a preceding + // filter), consider using :ref:`ExtensionWithMatcher + // ` instead. This + // provides a more flexible matching framework that can evaluate conditions before filter + // instantiation. See the :ref:`ext_authz filter documentation + // ` for examples. + type.matcher.v3.MetadataMatcher filter_enabled_metadata = 14; + + // Specifies whether to deny the requests when the filter is disabled. + // If :ref:`runtime_key ` is specified, + // Envoy will lookup the runtime key to determine whether to deny requests for filter-protected paths + // when the filter is disabled. If the filter is disabled in ``typed_per_filter_config`` for the path, + // requests will not be denied. + // + // If this field is not specified, all requests will be allowed when disabled. + // + // If a request is denied due to this setting, the response code in :ref:`status_on_error + // ` will + // be returned. + config.core.v3.RuntimeFeatureFlag deny_at_disable = 11; + + // Specifies if the peer certificate is sent to the external service. + // + // When this field is ``true``, Envoy will include the peer X.509 certificate, if available, in the + // :ref:`certificate`. + bool include_peer_certificate = 10; + + // Optional additional prefix to use when emitting statistics. 
This allows distinguishing + // emitted statistics between configured ``ext_authz`` filters in an HTTP filter chain. For example: + // + // .. code-block:: yaml + // + // http_filters: + // - name: envoy.filters.http.ext_authz + // typed_config: + // "@type": type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthz + // stat_prefix: waf # This emits ext_authz.waf.ok, ext_authz.waf.denied, etc. + // - name: envoy.filters.http.ext_authz + // typed_config: + // "@type": type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthz + // stat_prefix: blocker # This emits ext_authz.blocker.ok, ext_authz.blocker.denied, etc. + // + string stat_prefix = 13; + + // Optional labels that will be passed to :ref:`labels` in + // :ref:`destination`. + // The labels will be read from :ref:`metadata` with the specified key. + string bootstrap_metadata_labels_key = 15; + + // Check request to authorization server will include the client request headers that have a correspondent match + // in the list. If this option isn't specified, then + // all client request headers are included in the check request to a gRPC authorization server, whereas no client request headers + // (besides the ones allowed by default - see note below) are included in the check request to an HTTP authorization server. + // This inconsistency between gRPC and HTTP servers is to maintain backwards compatibility with legacy behavior. + // + // .. note:: + // + // For requests to an HTTP authorization server: in addition to the user's supplied matchers, ``Host``, ``Method``, ``Path``, + // ``Content-Length``, and ``Authorization`` are **additionally included** in the list. + // + // .. note:: + // + // For requests to an HTTP authorization server: the value of ``Content-Length`` will be set to ``0`` and the request to the + // authorization server will not have a message body. 
However, the check request can include the buffered + // client request body (controlled by :ref:`with_request_body + // ` setting); + // consequently, the value of ``Content-Length`` in the authorization request reflects the size of its payload. + // + // .. note:: + // + // This can be overridden by the field ``disallowed_headers`` below. That is, if a header + // matches for both ``allowed_headers`` and ``disallowed_headers``, the header will NOT be sent. + type.matcher.v3.ListStringMatcher allowed_headers = 17; + + // If set, specifically disallow any header in this list to be forwarded to the external + // authentication server. This overrides the above ``allowed_headers`` if a header matches both. + type.matcher.v3.ListStringMatcher disallowed_headers = 25; + + // Specifies if the TLS session level details like SNI are sent to the external service. + // + // When this field is ``true``, Envoy will include the SNI name used for TLSClientHello, if available, in the + // :ref:`tls_session`. + bool include_tls_session = 18; + + // Whether to increment cluster statistics (e.g. cluster..upstream_rq_*) on authorization failure. + // Defaults to ``true``. + google.protobuf.BoolValue charge_cluster_response_stats = 20; + + // Whether to encode the raw headers (i.e., unsanitized values and unconcatenated multi-line headers) + // in the authorization request. Works with both HTTP and gRPC clients. + // + // When this is set to ``true``, header values are not sanitized. Headers with the same key will also + // not be combined into a single, comma-separated header. + // Requests to gRPC services will populate the field + // :ref:`header_map`. + // Requests to HTTP services will be constructed with the unsanitized header values and preserved + // multi-line headers with the same key. + // + // If this field is set to ``false``, header values will be sanitized, with any non-UTF-8-compliant + // bytes replaced with ``'!'``. 
Headers with the same key will have their values concatenated into a + // single comma-separated header value. + // Requests to gRPC services will populate the field + // :ref:`headers`. + // Requests to HTTP services will have their header values sanitized and will not preserve + // multi-line headers with the same key. + // + // It is recommended to set this to ``true`` unless you rely on the previous behavior. + // + // It is set to ``false`` by default for backwards compatibility. + bool encode_raw_headers = 23; + + // Rules for what modifications an ext_authz server may make to the request headers before + // continuing decoding or forwarding upstream. + // + // If set, enables header mutation checking against the configured rules. Note that + // :ref:`HeaderMutationRules ` + // has defaults that change ext_authz behavior. Also note that if this field is set, + // ext_authz can no longer append to ``:``-prefixed headers. + // + // If unset, header mutation rule checking is completely disabled. + // + // Regardless of what is configured here, ext_authz cannot remove ``:``-prefixed headers. + // + // This field and ``validate_mutations`` have different use cases. ``validate_mutations`` enables + // correctness checks for all header and query parameter mutations (for example, invalid characters). + // This field allows the filter to reject mutations to specific headers. + config.common.mutation_rules.v3.HeaderMutationRules decoder_header_mutation_rules = 26; + + // Enable or disable ingestion of dynamic metadata from the ext_authz service. + // + // If ``false``, the filter will ignore dynamic metadata injected by the ext_authz service. If the + // ext_authz service tries injecting dynamic metadata, the filter will log, increment the + // ``ignored_dynamic_metadata`` stat, then continue handling the response. + // + // If ``true``, the filter will ingest dynamic metadata entries as normal. + // + // If unset, defaults to ``true``. 
+ google.protobuf.BoolValue enable_dynamic_metadata_ingestion = 27; + + // Additional metadata to be added to the filter state for logging purposes. The metadata will be + // added to StreamInfo's filter state under the namespace corresponding to the ext_authz filter + // name. + google.protobuf.Struct filter_metadata = 28; + + // When set to ``true``, the filter will emit per-stream stats for access logging. The filter state + // key will be the same as the filter name. + // + // If using Envoy gRPC, emits latency, bytes sent / received, upstream info, and upstream cluster + // info. If not using Envoy gRPC, emits only latency. + // + // .. note:: + // Stats are ONLY added to filter state if a check request is actually made to an ext_authz service. + // + // If this is ``false`` the filter will not emit stats, but filter_metadata will still be respected if + // it has a value. + // + // Field ``latency_us`` is exposed for CEL and logging when using gRPC or HTTP service. + // Fields ``bytesSent`` and ``bytesReceived`` are exposed for CEL and logging only when using gRPC service. + bool emit_filter_state_stats = 29; + + // Sets the maximum size (in bytes) of the response body that the filter will send downstream + // when a request is denied by the external authorization service. + // + // If the authorization server returns a response body larger than this configured limit, + // the body will be truncated to ``max_denied_response_body_bytes`` before being sent to the + // downstream client. + // + // If this field is not set or is set to 0, no truncation will occur, and the entire + // denied response body will be forwarded. + uint32 max_denied_response_body_bytes = 30; + + // When set to ``true``, the filter will enforce the response header map's count and size limits + // by sending a local reply when those limits are violated. 
+ // + // When set to ``false``, the filter will ignore the response header map's limits and add / set + // all response headers as specified by the external authorization service. + // + // Recommendation: enable if the external authorization service is not trusted. Otherwise, leave + // it ``false``. + // + // Defaults to ``false``. + bool enforce_response_header_limits = 31; +} + +// Configuration for buffering the request data. +message BufferSettings { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v2.BufferSettings"; + + // Sets the maximum size of a message body that the filter will hold in memory. Envoy will return + // ``HTTP 413`` and will *not* initiate the authorization process when the buffer reaches the size + // set in this field. + // + // .. note:: + // This setting will have precedence over :ref:`failure_mode_allow + // `. + uint32 max_request_bytes = 1 [(validate.rules).uint32 = {gt: 0}]; + + // When this field is ``true``, Envoy will buffer the message until ``max_request_bytes`` is reached. + // The authorization request will be dispatched and no 413 HTTP error will be returned by the + // filter. + // + // Defaults to ``false``. + bool allow_partial_message = 2; + + // If ``true``, the body sent to the external authorization service is set as raw bytes and populates + // :ref:`raw_body` + // in the HTTP request attribute context. Otherwise, :ref:`body + // ` will be populated + // with a UTF-8 string request body. + // + // This field only affects configurations using a :ref:`grpc_service + // `. In configurations that use + // an :ref:`http_service `, this + // has no effect. + // + // Defaults to ``false``. + bool pack_as_bytes = 3; +} + +// HttpService is used for raw HTTP communication between the filter and the authorization service. +// When configured, the filter will parse the client request and use these attributes to call the +// authorization server. 
Depending on the response, the filter may reject or accept the client +// request. +// +// .. note:: +// In any of these events, metadata can be added, removed or overridden by the filter: +// +// On authorization request, a list of allowed request headers may be supplied. See +// :ref:`allowed_headers +// ` +// for details. Additional headers metadata may be added to the authorization request. See +// :ref:`headers_to_add +// ` for +// details. +// +// On authorization response status ``HTTP 200 OK``, the filter will allow traffic to the upstream and +// additional headers metadata may be added to the original client request. See +// :ref:`allowed_upstream_headers +// ` +// for details. Additionally, the filter may add additional headers to the client's response. See +// :ref:`allowed_client_headers_on_success +// ` +// for details. +// +// On other authorization response statuses, the filter will not allow traffic. Additional headers +// metadata as well as body may be added to the client's response. See :ref:`allowed_client_headers +// ` +// for details. +// [#next-free-field: 10] +message HttpService { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v2.HttpService"; + + reserved 3, 4, 5, 6; + + // Sets the HTTP server URI which the authorization requests must be sent to. + config.core.v3.HttpUri server_uri = 1; + + // Sets a prefix to the value of authorization request header ``Path``. + string path_prefix = 2; + + // Settings used for controlling authorization request metadata. + AuthorizationRequest authorization_request = 7; + + // Settings used for controlling authorization response metadata. + AuthorizationResponse authorization_response = 8; + + // Optional retry policy for requests to the authorization server. + // If not set, no retries will be performed. + // + // .. note:: + // When this field is set, the ``ext_authz`` filter will buffer the request body for retry purposes. 
+ config.core.v3.RetryPolicy retry_policy = 9; +} + +message AuthorizationRequest { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v2.AuthorizationRequest"; + + // Authorization request includes the client request headers that have a corresponding match + // in the list. + // This field has been deprecated in favor of :ref:`allowed_headers + // `. + // + // .. note:: + // + // In addition to the user's supplied matchers, ``Host``, ``Method``, ``Path``, + // ``Content-Length``, and ``Authorization`` are **automatically included** in the list. + // + // .. note:: + // + // By default, the ``Content-Length`` header is set to ``0`` and the request to the authorization + // service has no message body. However, the authorization request *may* include the buffered + // client request body (controlled by :ref:`with_request_body + // ` + // setting); hence the value of its ``Content-Length`` reflects the size of its payload. + // + type.matcher.v3.ListStringMatcher allowed_headers = 1 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // Sets a list of headers that will be included in the request to the authorization service. + // + // .. note:: + // Client request headers with the same key will be overridden. + repeated config.core.v3.HeaderValue headers_to_add = 2; +} + +// [#next-free-field: 6] +message AuthorizationResponse { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v2.AuthorizationResponse"; + + // When this list is set, authorization + // response headers that have a correspondent match will be added to the original client request. + // + // .. note:: + // Existing headers will be overridden. + type.matcher.v3.ListStringMatcher allowed_upstream_headers = 1; + + // When this list is set, authorization + // response headers that have a correspondent match will be added to the original client request. + // + // .. 
note:: + // Existing headers will be appended. + type.matcher.v3.ListStringMatcher allowed_upstream_headers_to_append = 3; + + // When this list is set, authorization + // response headers that have a correspondent match will be added to the client's response. + // When a header is included in this list, ``Path``, ``Status``, ``Content-Length``, ``WWW-Authenticate`` and + // ``Location`` are automatically added. + // + // .. note:: + // When this list is *not* set, all the authorization response headers, except + // ``Authority (Host)``, will be in the response to the client. + type.matcher.v3.ListStringMatcher allowed_client_headers = 2; + + // When this list is set, authorization + // response headers that have a correspondent match will be added to the client's response when + // the authorization response itself is successful, i.e. not failed or denied. When this list is + // *not* set, no additional headers will be added to the client's response on success. + type.matcher.v3.ListStringMatcher allowed_client_headers_on_success = 4; + + // When this list is set, authorization + // response headers that have a correspondent match will be emitted as dynamic metadata to be consumed + // by the next filter. This metadata lives in a namespace specified by the canonical name of extension filter + // that requires it: + // + // - :ref:`envoy.filters.http.ext_authz ` for HTTP filter. + // - :ref:`envoy.filters.network.ext_authz ` for network filter. + type.matcher.v3.ListStringMatcher dynamic_metadata_from_headers = 5; +} + +// Extra settings on a per virtualhost/route/weighted-cluster level. +message ExtAuthzPerRoute { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v2.ExtAuthzPerRoute"; + + oneof override { + option (validate.required) = true; + + // Disable the ext auth filter for this particular vhost or route. + // If disabled is specified in multiple per-filter-configs, the most specific one will be used. 
+ // If the filter is disabled by default and this is set to ``false``, the filter will be enabled + // for this vhost or route. + bool disabled = 1; + + // Check request settings for this route. + CheckSettings check_settings = 2 [(validate.rules).message = {required: true}]; + } +} + +// Extra settings for the check request. +// [#next-free-field: 6] +message CheckSettings { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v2.CheckSettings"; + + // Context extensions to set on the CheckRequest's + // :ref:`AttributeContext.context_extensions` + // + // You can use this to provide extra context for the external authorization server on specific + // virtual hosts/routes. For example, adding a context extension on the virtual host level can + // give the ext-authz server information on what virtual host is used without needing to parse the + // host header. If CheckSettings is specified in multiple per-filter-configs, they will be merged + // in order, and the result will be used. + // + // Merge semantics for this field are such that keys from more specific configs override. + // + // .. note:: + // These settings are only applied to a filter configured with a + // :ref:`grpc_service`. + map context_extensions = 1 [(udpa.annotations.sensitive) = true]; + + // When set to ``true``, disable the configured :ref:`with_request_body + // ` for a specific route. + // + // Only one of ``disable_request_body_buffering`` and + // :ref:`with_request_body ` + // may be specified. + bool disable_request_body_buffering = 2; + + // Enable or override request body buffering, which is configured using the + // :ref:`with_request_body ` + // option for a specific route. + // + // Only one of ``with_request_body`` and + // :ref:`disable_request_body_buffering ` + // may be specified. + BufferSettings with_request_body = 3; + + // Override the external authorization service for this route. 
+ // This allows different routes to use different external authorization service backends + // and service types (gRPC or HTTP). If specified, this overrides the filter-level service + // configuration regardless of the original service type. + oneof service_override { + // Override with a gRPC service configuration. + config.core.v3.GrpcService grpc_service = 4; + + // Override with an HTTP service configuration. + HttpService http_service = 5; + } +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/gcp_authn/v3/gcp_authn.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/gcp_authn/v3/gcp_authn.proto new file mode 100644 index 00000000000..f4646389f7e --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/gcp_authn/v3/gcp_authn.proto @@ -0,0 +1,87 @@ +syntax = "proto3"; + +package envoy.extensions.filters.http.gcp_authn.v3; + +import "envoy/config/core/v3/base.proto"; +import "envoy/config/core/v3/http_uri.proto"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/wrappers.proto"; + +import "envoy/annotations/deprecation.proto"; +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3"; +option java_outer_classname = "GcpAuthnProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/gcp_authn/v3;gcp_authnv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: GCP authentication] +// GCP authentication :ref:`configuration overview `. +// [#extension: envoy.filters.http.gcp_authn] + +// Filter configuration. +// [#next-free-field: 7] +message GcpAuthnFilterConfig { + // The HTTP URI to fetch tokens from GCE Metadata Server(https://cloud.google.com/compute/docs/metadata/overview). 
+ // The URL format is "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity?audience=[AUDIENCE]" + // + // This field is deprecated because it does not match the API surface provided by the google auth libraries. + // Control planes should not attempt to override the metadata server URI. + // The cluster and timeout can be configured using the ``cluster`` and ``timeout`` fields instead. + // For backward compatibility, the cluster and timeout configured in this field will be used + // if the new ``cluster`` and ``timeout`` fields are not set. + config.core.v3.HttpUri http_uri = 1 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // Retry policy for fetching tokens. + // Not supported by all data planes. + config.core.v3.RetryPolicy retry_policy = 2; + + // Token cache configuration. This field is optional. + TokenCacheConfig cache_config = 3; + + // Request header location to extract the token. By default (i.e. if this field is not specified), the token + // is extracted to the Authorization HTTP header, in the format "Authorization: Bearer ". + // Not supported by all data planes. + TokenHeader token_header = 4; + + // Cluster to send traffic to the GCE metadata server. Not supported + // by all data planes; a data plane may instead have its own mechanism + // for contacting the metadata server. + string cluster = 5; + + // Timeout for fetching the tokens from the GCE metadata server. + // Not supported by all data planes. + google.protobuf.Duration timeout = 6 [(validate.rules).duration = { + lt {seconds: 4294967296} + gte {} + }]; +} + +// Audience is the URL of the receiving service that performs token authentication. +// It will be provided to the filter through cluster's typed_filter_metadata. +message Audience { + string url = 1 [(validate.rules).string = {min_len: 1}]; +} + +// Token Cache configuration. +message TokenCacheConfig { + // The number of cache entries. 
The maximum number of entries is INT64_MAX as it is constrained by underlying cache implementation. + // Default value 0 (i.e., proto3 defaults) disables the cache by default. Other default values will enable the cache. + google.protobuf.UInt64Value cache_size = 1 [(validate.rules).uint64 = {lte: 9223372036854775807}]; +} + +message TokenHeader { + // The HTTP header's name. + string name = 1 + [(validate.rules).string = {min_len: 1 well_known_regex: HTTP_HEADER_NAME strict: false}]; + + // The header's prefix. The format is "value_prefix" + // For example, for "Authorization: Bearer ", value_prefix="Bearer " with a space at the + // end. + string value_prefix = 2 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto index 649869a255d..a37efe157db 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto @@ -4,7 +4,6 @@ package envoy.extensions.filters.http.rbac.v3; import "envoy/config/rbac/v3/rbac.proto"; -import "xds/annotations/v3/status.proto"; import "xds/type/matcher/v3/matcher.proto"; import "udpa/annotations/migrate.proto"; @@ -27,48 +26,51 @@ message RBAC { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.http.rbac.v2.RBAC"; - // Specify the RBAC rules to be applied globally. - // If absent, no enforcing RBAC policy will be applied. - // If present and empty, DENY. - // If both rules and matcher are configured, rules will be ignored. + // The primary RBAC policy which will be applied globally, to all the incoming requests. + // + // * If absent, no RBAC enforcement occurs. + // * If set but empty, all requests are denied. + // + // .. 
note:: + // + // When both ``rules`` and ``matcher`` are configured, ``rules`` will be ignored. + // config.rbac.v3.RBAC rules = 1 [(udpa.annotations.field_migrate).oneof_promotion = "rules_specifier"]; // If specified, rules will emit stats with the given prefix. - // This is useful to distinguish the stat when there are more than 1 RBAC filter configured with - // rules. + // This is useful for distinguishing metrics when multiple RBAC filters are configured. string rules_stat_prefix = 6; - // The match tree to use when resolving RBAC action for incoming requests. Requests do not - // match any matcher will be denied. - // If absent, no enforcing RBAC matcher will be applied. - // If present and empty, deny all requests. - xds.type.matcher.v3.Matcher matcher = 4 [ - (udpa.annotations.field_migrate).oneof_promotion = "rules_specifier", - (xds.annotations.v3.field_status).work_in_progress = true - ]; - - // Shadow rules are not enforced by the filter (i.e., returning a 403) - // but will emit stats and logs and can be used for rule testing. - // If absent, no shadow RBAC policy will be applied. - // If both shadow rules and shadow matcher are configured, shadow rules will be ignored. + // Match tree for evaluating RBAC actions on incoming requests. Requests not matching any matcher will be denied. + // + // * If absent, no RBAC enforcement occurs. + // * If set but empty, all requests are denied. + // + xds.type.matcher.v3.Matcher matcher = 4 + [(udpa.annotations.field_migrate).oneof_promotion = "rules_specifier"]; + + // Shadow policy for testing RBAC rules without enforcing them. These rules generate stats and logs but do not deny + // requests. If absent, no shadow RBAC policy will be applied. + // + // .. note:: + // + // When both ``shadow_rules`` and ``shadow_matcher`` are configured, ``shadow_rules`` will be ignored. 
+ // config.rbac.v3.RBAC shadow_rules = 2 [(udpa.annotations.field_migrate).oneof_promotion = "shadow_rules_specifier"]; - // The match tree to use for emitting stats and logs which can be used for rule testing for - // incoming requests. // If absent, no shadow matcher will be applied. - xds.type.matcher.v3.Matcher shadow_matcher = 5 [ - (udpa.annotations.field_migrate).oneof_promotion = "shadow_rules_specifier", - (xds.annotations.v3.field_status).work_in_progress = true - ]; + // Match tree for testing RBAC rules through stats and logs without enforcing them. + // If absent, no shadow matching occurs. + xds.type.matcher.v3.Matcher shadow_matcher = 5 + [(udpa.annotations.field_migrate).oneof_promotion = "shadow_rules_specifier"]; // If specified, shadow rules will emit stats with the given prefix. - // This is useful to distinguish the stat when there are more than 1 RBAC filter configured with - // shadow rules. + // This is useful for distinguishing metrics when multiple RBAC filters use shadow rules. string shadow_rules_stat_prefix = 3; - // If track_per_rule_stats is true, counters will be published for each rule and shadow rule. + // If ``track_per_rule_stats`` is ``true``, counters will be published for each rule and shadow rule. bool track_per_rule_stats = 7; } @@ -78,7 +80,7 @@ message RBACPerRoute { reserved 1; - // Override the global configuration of the filter with this new config. - // If absent, the global RBAC policy will be disabled for this route. + // Per-route specific RBAC configuration that overrides the global RBAC configuration. + // If absent, RBAC policy will be disabled for this route. 
RBAC rbac = 2; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/router/v3/router.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/router/v3/router.proto index 75bca960da1..7da658bcb33 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/router/v3/router.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/router/v3/router.proto @@ -23,7 +23,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Router :ref:`configuration overview `. // [#extension: envoy.filters.http.router] -// [#next-free-field: 10] +// [#next-free-field: 11] message Router { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.http.router.v2.Router"; @@ -119,11 +119,11 @@ message Router { // for more details. bool suppress_grpc_request_failure_code_stats = 7; + // Optional HTTP filters for the upstream HTTP filter chain. + // // .. note:: // Upstream HTTP filters are currently in alpha. // - // Optional HTTP filters for the upstream HTTP filter chain. - // // These filters will be applied for all requests that pass through the router. // They will also be applied to shadowed requests. // Upstream HTTP filters cannot change route or cluster. @@ -134,4 +134,10 @@ message Router { // upstream HTTP filters will count as a final response if hedging is configured. // [#extension-category: envoy.filters.http.upstream] repeated network.http_connection_manager.v3.HttpFilter upstream_http_filters = 8; + + // If set to true, Envoy will reject ``CONNECT`` requests that send data before + // receiving a ``200`` response from the upstream. This early data behavior + // is common for latency reduction but can cause issues with some upstreams. + // Defaults to false to allow early data and be compatible with common behavior. 
+ google.protobuf.BoolValue reject_connect_request_early_data = 10; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto index 9e7274daa53..9d8cf8bf4fd 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto @@ -20,6 +20,8 @@ import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; +import "xds/type/matcher/v3/matcher.proto"; + import "envoy/annotations/deprecation.proto"; import "udpa/annotations/migrate.proto"; import "udpa/annotations/security.proto"; @@ -37,7 +39,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // HTTP connection manager :ref:`configuration overview `. // [#extension: envoy.filters.network.http_connection_manager] -// [#next-free-field: 58] +// [#next-free-field: 61] message HttpConnectionManager { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.http_connection_manager.v2.HttpConnectionManager"; @@ -58,9 +60,8 @@ message HttpConnectionManager { // Prior knowledge is allowed). HTTP2 = 2; - // [#not-implemented-hide:] QUIC implementation is not production ready yet. Use this enum with - // caution to prevent accidental execution of QUIC code. I.e. `!= HTTP2` is no longer sufficient - // to distinguish HTTP1 and HTTP2 traffic. + // The connection manager will assume that the client is speaking HTTP/3. + // This needs to be consistent with listener and transport socket config. 
HTTP3 = 3; } @@ -100,41 +101,53 @@ message HttpConnectionManager { ALWAYS_FORWARD_ONLY = 4; } - // Determines the action for request that contain %2F, %2f, %5C or %5c sequences in the URI path. + // Determines the action for request that contain ``%2F``, ``%2f``, ``%5C`` or ``%5c`` sequences in the URI path. // This operation occurs before URL normalization and the merge slashes transformations if they were enabled. enum PathWithEscapedSlashesAction { // Default behavior specific to implementation (i.e. Envoy) of this configuration option. // Envoy, by default, takes the KEEP_UNCHANGED action. - // NOTE: the implementation may change the default behavior at-will. + // + // .. note:: + // + // The implementation may change the default behavior at-will. IMPLEMENTATION_SPECIFIC_DEFAULT = 0; // Keep escaped slashes. KEEP_UNCHANGED = 1; // Reject client request with the 400 status. gRPC requests will be rejected with the INTERNAL (13) error code. - // The "httpN.downstream_rq_failed_path_normalization" counter is incremented for each rejected request. + // The ``httpN.downstream_rq_failed_path_normalization`` counter is incremented for each rejected request. REJECT_REQUEST = 2; - // Unescape %2F and %5C sequences and redirect request to the new path if these sequences were present. + // Unescape ``%2F`` and ``%5C`` sequences and redirect request to the new path if these sequences were present. // Redirect occurs after path normalization and merge slashes transformations if they were configured. - // NOTE: gRPC requests will be rejected with the INTERNAL (13) error code. - // This option minimizes possibility of path confusion exploits by forcing request with unescaped slashes to - // traverse all parties: downstream client, intermediate proxies, Envoy and upstream server. - // The "httpN.downstream_rq_redirected_with_normalized_path" counter is incremented for each - // redirected request. + // + // .. 
note:: + // + // gRPC requests will be rejected with the INTERNAL (13) error code. This option minimizes possibility of path + // confusion exploits by forcing request with unescaped slashes to traverse all parties: downstream client, + // intermediate proxies, Envoy and upstream server. The ``httpN.downstream_rq_redirected_with_normalized_path`` + // counter is incremented for each redirected request. + // UNESCAPE_AND_REDIRECT = 3; - // Unescape %2F and %5C sequences. - // Note: this option should not be enabled if intermediaries perform path based access control as - // it may lead to path confusion vulnerabilities. + // Unescape ``%2F`` and ``%5C`` sequences. + // + // .. note:: + // + // This option should not be enabled if intermediaries perform path based access control as it may lead to path + // confusion vulnerabilities. + // UNESCAPE_AND_FORWARD = 4; } - // [#next-free-field: 11] + // [#next-free-field: 13] message Tracing { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.http_connection_manager.v2.HttpConnectionManager.Tracing"; + // This OperationName makes no sense and is unnecessary in the current tracing API. + // [#not-implemented-hide:] enum OperationName { // The HTTP listener is used for ingress/incoming requests. INGRESS = 0; @@ -186,14 +199,6 @@ message HttpConnectionManager { // Configuration for an external tracing provider. // If not specified, no tracing will be performed. - // - // .. attention:: - // Please be aware that ``envoy.tracers.opencensus`` provider can only be configured once - // in Envoy lifetime. - // Any attempts to reconfigure it or to use different configurations for different HCM filters - // will be rejected. - // Such a constraint is inherent to OpenCensus itself. It cannot be overcome without changes - // on OpenCensus side. config.trace.v3.Tracing.Http provider = 9; // Create separate tracing span for each upstream request if true. 
And if this flag is set to true, @@ -216,6 +221,28 @@ message HttpConnectionManager { // // The default value is false for now for backward compatibility. google.protobuf.BoolValue spawn_upstream_span = 10; + + // The operation name of the span which will be used for tracing. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + // + // This field will take precedence over and make following settings ineffective: + // + // * :ref:`route decorator ` and + // * :ref:`x-envoy-decorator-operation ` + // header will be ignored. + string operation = 11; + + // The operation name of the upstream span which will be used for tracing. + // This only takes effect when ``spawn_upstream_span`` is set to true and the upstream + // span is created. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + string upstream_operation = 12; } message InternalAddressConfig { @@ -262,18 +289,26 @@ message HttpConnectionManager { bool uri = 5; } + // The configuration for forwarding client cert details. + message ForwardClientCertConfig { + // How to handle the XFCC header. + ForwardClientCertDetails forward_client_cert_details = 1; + + // How to set the current client cert details. + SetCurrentClientCertDetails set_current_client_cert_details = 2; + } + // The configuration for HTTP upgrades. // For each upgrade type desired, an UpgradeConfig must be added. // // .. warning:: // - // The current implementation of upgrade headers does not handle - // multi-valued upgrade headers. Support for multi-valued headers may be - // added in the future if needed. + // The current implementation of upgrade headers does not handle multi-valued upgrade headers. 
Support for + // multi-valued headers may be added in the future if needed. // // .. warning:: - // The current implementation of upgrade headers does not work with HTTP/2 - // upstreams. + // The current implementation of upgrade headers does not work with HTTP/2 upstreams. + // message UpgradeConfig { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.http_connection_manager.v2.HttpConnectionManager." @@ -305,7 +340,10 @@ message HttpConnectionManager { // `) will apply to the ``:path`` header // destined for the upstream. // - // Note: access logging and tracing will show the original ``:path`` header. + // .. note:: + // + // Access logging and tracing will show the original ``:path`` header. + // message PathNormalizationOptions { // [#not-implemented-hide:] Normalization applies internally before any processing of requests by // HTTP filters, routing, and matching *and* will affect the forwarded ``:path`` header. Defaults @@ -443,10 +481,25 @@ message HttpConnectionManager { Tracing tracing = 7; // Additional settings for HTTP requests handled by the connection manager. These will be - // applicable to both HTTP1 and HTTP2 requests. + // applicable to both HTTP/1.1 and HTTP/2 requests. config.core.v3.HttpProtocolOptions common_http_protocol_options = 35 [(udpa.annotations.security).configure_for_untrusted_downstream = true]; + // If set to ``true``, Envoy will not initiate an immediate drain timer for downstream HTTP/1 connections + // once :ref:`common_http_protocol_options.max_connection_duration + // ` is exceeded. + // Instead, Envoy will wait until the next downstream request arrives, add a ``connection: close`` header + // to the response, and then gracefully close the connection once the stream has completed. + // + // This behavior adheres to `RFC 9112, Section 9.6 `_. 
+ // + // If set to ``false``, exceeding ``max_connection_duration`` triggers Envoy's default drain behavior for HTTP/1, + // where the connection is eventually closed after all active streams finish. + // + // This option has no effect if ``max_connection_duration`` is not configured. + // Defaults to ``false``. + bool http1_safe_max_connection_duration = 58; + // Additional HTTP/1 settings that are passed to the HTTP/1 codec. // [#comment:TODO: The following fields are ignored when the // :ref:`header validation configuration ` @@ -459,7 +512,6 @@ message HttpConnectionManager { [(udpa.annotations.security).configure_for_untrusted_downstream = true]; // Additional HTTP/3 settings that are passed directly to the HTTP/3 codec. - // [#not-implemented-hide:] config.core.v3.Http3ProtocolOptions http3_protocol_options = 44; // An optional override that the connection manager will write to the server @@ -480,7 +532,16 @@ message HttpConnectionManager { // The maximum request headers size for incoming connections. // If unconfigured, the default max request headers allowed is 60 KiB. + // The default value can be overridden by setting runtime key ``envoy.reloadable_features.max_request_headers_size_kb``. // Requests that exceed this limit will receive a 431 response. + // + // .. note:: + // + // Currently some protocol codecs impose limits on the maximum size of a single header. + // + // * HTTP/2 (when using nghttp2) limits a single header to around 100kb. + // * HTTP/3 limits a single header to around 1024kb. + // google.protobuf.UInt32Value max_request_headers_kb = 29 [(validate.rules).uint32 = {lte: 8192 gt: 0}]; @@ -501,16 +562,6 @@ message HttpConnectionManager { // is terminated with a 408 Request Timeout error code if no upstream response // header has been received, otherwise a stream reset occurs. 
// - // This timeout also specifies the amount of time that Envoy will wait for the peer to open enough - // window to write any remaining stream data once the entirety of stream data (local end stream is - // true) has been buffered pending available window. In other words, this timeout defends against - // a peer that does not release enough window to completely write the stream, even though all - // data has been proxied within available flow control windows. If the timeout is hit in this - // case, the :ref:`tx_flush_timeout ` counter will be - // incremented. Note that :ref:`max_stream_duration - // ` does not apply to - // this corner case. - // // If the :ref:`overload action ` "envoy.overload_actions.reduce_timeouts" // is configured, this timeout is scaled according to the value for // :ref:`HTTP_DOWNSTREAM_STREAM_IDLE `. @@ -523,9 +574,29 @@ message HttpConnectionManager { // // A value of 0 will completely disable the connection manager stream idle // timeout, although per-route idle timeout overrides will continue to apply. + // + // This timeout is also used as the default value for :ref:`stream_flush_timeout + // `. google.protobuf.Duration stream_idle_timeout = 24 [(udpa.annotations.security).configure_for_untrusted_downstream = true]; + // The stream flush timeout for connections managed by the connection manager. + // + // If not specified, the value of stream_idle_timeout is used. This is for backwards compatibility + // since this was the original behavior. In essence this timeout is an override for the + // stream_idle_timeout that applies specifically to the end of stream flush case. + // + // This timeout specifies the amount of time that Envoy will wait for the peer to open enough + // window to write any remaining stream data once the entirety of stream data (local end stream is + // true) has been buffered pending available window. 
In other words, this timeout defends against + // a peer that does not release enough window to completely write the stream, even though all + // data has been proxied within available flow control windows. If the timeout is hit in this + // case, the :ref:`tx_flush_timeout ` counter will be + // incremented. Note that :ref:`max_stream_duration + // ` does not apply to + // this corner case. + google.protobuf.Duration stream_flush_timeout = 59; + // The amount of time that Envoy will wait for the entire request to be received. // The timer is activated when the request is initiated, and is disarmed when the last byte of the // request is sent upstream (i.e. all decoding filters have processed the request), OR when the @@ -547,9 +618,10 @@ message HttpConnectionManager { // race with the final GOAWAY frame. During this grace period, Envoy will // continue to accept new streams. After the grace period, a final GOAWAY // frame is sent and Envoy will start refusing new streams. Draining occurs - // both when a connection hits the idle timeout or during general server - // draining. The default grace period is 5000 milliseconds (5 seconds) if this - // option is not specified. + // either when a connection hits the idle timeout, when :ref:`max_connection_duration + // ` + // is reached, or during general server draining. The default grace period is + // 5000 milliseconds (5 seconds) if this option is not specified. google.protobuf.Duration drain_timeout = 12; // The delayed close timeout is for downstream connections managed by the HTTP connection manager. @@ -557,57 +629,67 @@ message HttpConnectionManager { // during which Envoy will wait for the peer to close (i.e., a TCP FIN/RST is received by Envoy // from the downstream connection) prior to Envoy closing the socket associated with that // connection. - // NOTE: This timeout is enforced even when the socket associated with the downstream connection - // is pending a flush of the write buffer. 
However, any progress made writing data to the socket - // will restart the timer associated with this timeout. This means that the total grace period for - // a socket in this state will be - // +. + // + // .. note:: + // + // This timeout is enforced even when the socket associated with the downstream connection is pending a flush of + // the write buffer. However, any progress made writing data to the socket will restart the timer associated with + // this timeout. This means that the total grace period for a socket in this state will be + // +. // // Delaying Envoy's connection close and giving the peer the opportunity to initiate the close // sequence mitigates a race condition that exists when downstream clients do not drain/process // data in a connection's receive buffer after a remote close has been detected via a socket - // write(). This race leads to such clients failing to process the response code sent by Envoy, + // ``write()``. This race leads to such clients failing to process the response code sent by Envoy, // which could result in erroneous downstream processing. // // If the timeout triggers, Envoy will close the connection's socket. // // The default timeout is 1000 ms if this option is not specified. // - // .. NOTE:: + // .. note:: // To be useful in avoiding the race condition described above, this timeout must be set // to *at least* +<100ms to account for // a reasonable "worst" case processing time for a full iteration of Envoy's event loop>. // - // .. WARNING:: - // A value of 0 will completely disable delayed close processing. When disabled, the downstream + // .. warning:: + // A value of ``0`` will completely disable delayed close processing. When disabled, the downstream // connection's socket will be closed immediately after the write flush is completed or will // never close if the write flush does not complete. 
+ // google.protobuf.Duration delayed_close_timeout = 26; // Configuration for :ref:`HTTP access logs ` // emitted by the connection manager. repeated config.accesslog.v3.AccessLog access_log = 13; + // The interval to flush the above access logs. + // // .. attention:: - // This field is deprecated in favor of - // :ref:`access_log_flush_interval - // `. - // Note that if both this field and :ref:`access_log_flush_interval - // ` - // are specified, the former (deprecated field) is ignored. + // + // This field is deprecated in favor of + // :ref:`access_log_flush_interval + // `. + // Note that if both this field and :ref:`access_log_flush_interval + // ` + // are specified, the former (deprecated field) is ignored. google.protobuf.Duration access_log_flush_interval = 54 [ deprecated = true, (validate.rules).duration = {gte {nanos: 1000000}}, (envoy.annotations.deprecated_at_minor_version) = "3.0" ]; + // If set to true, HCM will flush an access log once when a new HTTP request is received, after the request + // headers have been evaluated, and before iterating through the HTTP filter chain. + // // .. attention:: - // This field is deprecated in favor of - // :ref:`flush_access_log_on_new_request - // `. - // Note that if both this field and :ref:`flush_access_log_on_new_request - // ` - // are specified, the former (deprecated field) is ignored. + // + // This field is deprecated in favor of + // :ref:`flush_access_log_on_new_request + // `. + // Note that if both this field and :ref:`flush_access_log_on_new_request + // ` + // are specified, the former (deprecated field) is ignored. bool flush_access_log_on_new_request = 55 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; @@ -631,20 +713,19 @@ message HttpConnectionManager { // :ref:`config_http_conn_man_headers_x-forwarded-for` for more information. uint32 xff_num_trusted_hops = 19; - // The configuration for the original IP detection extensions. 
+ // Configuration for original IP detection extensions. // - // When configured the extensions will be called along with the request headers - // and information about the downstream connection, such as the directly connected address. - // Each extension will then use these parameters to decide the request's effective remote address. - // If an extension fails to detect the original IP address and isn't configured to reject - // the request, the HCM will try the remaining extensions until one succeeds or rejects - // the request. If the request isn't rejected nor any extension succeeds, the HCM will - // fallback to using the remote address. + // When these extensions are configured, Envoy will invoke them with the incoming request headers and + // details about the downstream connection, including the directly connected address. Each extension uses + // this information to determine the effective remote IP address for the request. If an extension cannot + // identify the original IP address and isn't set to reject the request, Envoy will sequentially attempt + // the remaining extensions until one successfully determines the IP or explicitly rejects the request. + // If all extensions fail without rejection, Envoy defaults to using the directly connected remote address. // - // .. WARNING:: - // Extensions cannot be used in conjunction with :ref:`use_remote_address + // .. warning:: + // These extensions cannot be configured simultaneously with :ref:`use_remote_address // ` - // nor :ref:`xff_num_trusted_hops + // or :ref:`xff_num_trusted_hops // `. // // [#extension-category: envoy.http.original_ip_detection] @@ -663,6 +744,34 @@ message HttpConnectionManager { // purposes. If unspecified, only RFC1918 IP addresses will be considered internal. // See the documentation for :ref:`config_http_conn_man_headers_x-envoy-internal` for more // information about internal/external addresses. + // + // .. 
warning:: + // As of Envoy 1.33.0 no IP addresses will be considered trusted. If you have tooling such as probes + // on your private network which need to be treated as trusted (e.g. changing arbitrary x-envoy headers) + // you will have to manually include those addresses or CIDR ranges like: + // + // .. validated-code-block:: yaml + // :type-name: envoy.extensions.filters.network.http_connection_manager.v3.InternalAddressConfig + // + // cidr_ranges: + // address_prefix: 10.0.0.0 + // prefix_len: 8 + // cidr_ranges: + // address_prefix: 192.168.0.0 + // prefix_len: 16 + // cidr_ranges: + // address_prefix: 172.16.0.0 + // prefix_len: 12 + // cidr_ranges: + // address_prefix: 127.0.0.1 + // prefix_len: 32 + // cidr_ranges: + // address_prefix: fd00:: + // prefix_len: 8 + // cidr_ranges: + // address_prefix: ::1 + // prefix_len: 128 + // InternalAddressConfig internal_address_config = 25; // If set, Envoy will not append the remote address to the @@ -710,6 +819,53 @@ message HttpConnectionManager { // value. SetCurrentClientCertDetails set_current_client_cert_details = 17; + // The matcher for forwarding client cert details. This allows per-request configuration + // of forward client cert behavior based on request properties. If a matcher is configured + // and matches a request, the matched action's forward client cert config will be used. + // If the matcher is not configured or doesn't match, the static + // :ref:`forward_client_cert_details + // ` + // and + // :ref:`set_current_client_cert_details + // ` + // config will be used as fallback. + // + // Example: If the x-forwarded-client-cert header contains "trusted-client", use APPEND_FORWARD, + // otherwise use SANITIZE_SET: + // + // .. 
code-block:: yaml + // + // forward_client_cert_matcher: + // matcher_list: + // matchers: + // - predicate: + // single_predicate: + // input: + // name: envoy.matching.inputs.request_headers + // typed_config: + // "@type": type.googleapis.com/envoy.type.matcher.v3.HttpRequestHeaderMatchInput + // header_name: "x-forwarded-client-cert" + // value_match: + // string_match: + // contains: "trusted-client" + // on_match: + // action: + // name: forward_client_cert + // typed_config: + // "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager.ForwardClientCertConfig + // forward_client_cert_details: APPEND_FORWARD + // set_current_client_cert_details: + // uri: true + // on_no_match: + // action: + // name: forward_client_cert + // typed_config: + // "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager.ForwardClientCertConfig + // forward_client_cert_details: SANITIZE_SET + // set_current_client_cert_details: + // uri: true + xds.type.matcher.v3.Matcher forward_client_cert_matcher = 60; + // If proxy_100_continue is true, Envoy will proxy incoming "Expect: // 100-continue" headers upstream, and forward "100 Continue" responses // downstream. If this is false or not set, Envoy will instead strip the @@ -972,7 +1128,7 @@ message Rds { "envoy.config.filter.network.http_connection_manager.v2.Rds"; // Configuration source specifier for RDS. - config.core.v3.ConfigSource config_source = 1 [(validate.rules).message = {required: true}]; + config.core.v3.ConfigSource config_source = 1; // The name of the route configuration. This name will be passed to the RDS // API. 
This allows an Envoy configuration with multiple HTTP listeners (and diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/call_credentials/access_token/v3/access_token_credentials.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/call_credentials/access_token/v3/access_token_credentials.proto new file mode 100644 index 00000000000..45ee3839e6f --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/call_credentials/access_token/v3/access_token_credentials.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package envoy.extensions.grpc_service.call_credentials.access_token.v3; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3"; +option java_outer_classname = "AccessTokenCredentialsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/grpc_service/call_credentials/access_token/v3;access_tokenv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: gRPC Access Token Credentials] + +// [#not-implemented-hide:] +message AccessTokenCredentials { + // The access token. 
+ string token = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/google_default/v3/google_default_credentials.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/google_default/v3/google_default_credentials.proto new file mode 100644 index 00000000000..77c3af41fdd --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/google_default/v3/google_default_credentials.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package envoy.extensions.grpc_service.channel_credentials.google_default.v3; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3"; +option java_outer_classname = "GoogleDefaultCredentialsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/grpc_service/channel_credentials/google_default/v3;google_defaultv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: gRPC Google Default Credentials] + +// [#not-implemented-hide:] +message GoogleDefaultCredentials { +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/insecure/v3/insecure_credentials.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/insecure/v3/insecure_credentials.proto new file mode 100644 index 00000000000..70d58451e2d --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/insecure/v3/insecure_credentials.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package envoy.extensions.grpc_service.channel_credentials.insecure.v3; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.insecure.v3"; +option java_outer_classname = 
"InsecureCredentialsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/grpc_service/channel_credentials/insecure/v3;insecurev3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: gRPC Insecure Credentials] + +// [#not-implemented-hide:] +message InsecureCredentials { +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/local/v3/local_credentials.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/local/v3/local_credentials.proto new file mode 100644 index 00000000000..00514a0e847 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/local/v3/local_credentials.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package envoy.extensions.grpc_service.channel_credentials.local.v3; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.local.v3"; +option java_outer_classname = "LocalCredentialsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/grpc_service/channel_credentials/local/v3;localv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: gRPC Local Credentials] + +// [#not-implemented-hide:] +message LocalCredentials { +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/tls/v3/tls_credentials.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/tls/v3/tls_credentials.proto new file mode 100644 index 00000000000..f64c16bb684 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/tls/v3/tls_credentials.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package 
envoy.extensions.grpc_service.channel_credentials.tls.v3; + +import "envoy/extensions/transport_sockets/tls/v3/tls.proto"; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.tls.v3"; +option java_outer_classname = "TlsCredentialsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/grpc_service/channel_credentials/tls/v3;tlsv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: gRPC TLS Credentials] + +// [#not-implemented-hide:] +message TlsCredentials { + // The certificate provider instance for the root cert. Must be set. + transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance root_certificate_provider = + 1; + + // The certificate provider instance for the identity cert. Optional; + // if unset, no identity certificate will be sent to the server. + transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance + identity_certificate_provider = 2; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/xds/v3/xds_credentials.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/xds/v3/xds_credentials.proto new file mode 100644 index 00000000000..ba8d471dd49 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/xds/v3/xds_credentials.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +package envoy.extensions.grpc_service.channel_credentials.xds.v3; + +import "google/protobuf/any.proto"; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.xds.v3"; +option java_outer_classname = "XdsCredentialsProto"; +option java_multiple_files = true; +option go_package = 
"github.com/envoyproxy/go-control-plane/envoy/extensions/grpc_service/channel_credentials/xds/v3;xdsv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: gRPC xDS Credentials] + +// [#not-implemented-hide:] +message XdsCredentials { + // Fallback credentials. Required. + google.protobuf.Any fallback_credentials = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3/client_side_weighted_round_robin.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3/client_side_weighted_round_robin.proto index c70360a0946..c55d30b89e0 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3/client_side_weighted_round_robin.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3/client_side_weighted_round_robin.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package envoy.extensions.load_balancing_policies.client_side_weighted_round_robin.v3; +import "envoy/extensions/load_balancing_policies/common/v3/common.proto"; + import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; @@ -15,7 +17,7 @@ option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/loa option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Client-Side Weighted Round Robin Load Balancing Policy] -// [#not-implemented-hide:] +// [#extension: envoy.load_balancing_policies.client_side_weighted_round_robin] // Configuration for the client_side_weighted_round_robin LB policy. // @@ -30,11 +32,19 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // regardless of result. Only failed queries count toward eps. 
A config // parameter error_utilization_penalty controls the penalty to adjust endpoint // weights using eps and qps. The weight of a given endpoint is computed as: -// qps / (utilization + eps/qps * error_utilization_penalty) +// ``qps / (utilization + eps/qps * error_utilization_penalty)``. +// +// Note that Envoy will forward the ORCA response headers/trailers from the upstream +// cluster to the downstream client. This means that if the downstream client is also +// configured to use ``client_side_weighted_round_robin`` it will load balance against +// Envoy based on upstream weights. This can happen when Envoy is used as a reverse proxy. +// To avoid this issue you can configure the :ref:`header_mutation filter ` to remove +// the ORCA payload from the response headers/trailers. // -// See the :ref:`load balancing architecture overview` for more information. +// See the :ref:`load balancing architecture +// overview` for more information. // -// [#next-free-field: 7] +// [#next-free-field: 9] message ClientSideWeightedRoundRobin { // Whether to enable out-of-band utilization reporting collection from // the endpoints. By default, per-request utilization reporting is used. @@ -68,4 +78,14 @@ message ClientSideWeightedRoundRobin { // calculated as eps/qps. Configuration is rejected if this value is negative. // Default is 1.0. google.protobuf.FloatValue error_utilization_penalty = 6 [(validate.rules).float = {gte: 0.0}]; + + // By default, endpoint weight is computed based on the :ref:`application_utilization ` field reported by the endpoint. + // If that field is not set, then utilization will instead be computed by taking the max of the values of the metrics specified here. + // For map fields in the ORCA proto, the string will be of the form ``.``. For example, the string ``named_metrics.foo`` will mean to look for the key ``foo`` in the ORCA :ref:`named_metrics ` field. 
+ // If none of the specified metrics are present in the load report, then :ref:`cpu_utilization ` is used instead. + repeated string metric_names_for_computing_utilization = 7; + + // Configuration for slow start mode. + // If this configuration is not set, slow start will not be not enabled. + common.v3.SlowStartConfig slow_start_config = 8; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/common/v3/common.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/common/v3/common.proto index 51520690a29..22faf11b9c5 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/common/v3/common.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/common/v3/common.proto @@ -3,11 +3,13 @@ syntax = "proto3"; package envoy.extensions.load_balancing_policies.common.v3; import "envoy/config/core/v3/base.proto"; +import "envoy/config/route/v3/route_components.proto"; import "envoy/type/v3/percent.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; +import "envoy/annotations/deprecation.proto"; import "udpa/annotations/status.proto"; import "validate/validate.proto"; @@ -22,7 +24,34 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; message LocalityLbConfig { // Configuration for :ref:`zone aware routing // `. + // [#next-free-field: 7] message ZoneAwareLbConfig { + // Basis for computing per-locality percentages in zone-aware routing. + enum LocalityBasis { + // Use the number of healthy hosts in each locality. + HEALTHY_HOSTS_NUM = 0; + + // Use the weights of healthy hosts in each locality. + HEALTHY_HOSTS_WEIGHT = 1; + } + + // Configures Envoy to always route requests to the local zone regardless of the + // upstream zone structure. 
In Envoy's default configuration, traffic is distributed proportionally + // across all upstream hosts while trying to maximize local routing when possible. The approach + // with force_local_zone aims to be more predictable and if there are upstream hosts in the local + // zone, they will receive all traffic. + // * :ref:`runtime values `. + // * :ref:`Zone aware routing support `. + message ForceLocalZone { + // Configures the minimum number of upstream hosts in the local zone required when force_local_zone + // is enabled. If the number of upstream hosts in the local zone is less than the specified value, + // Envoy will fall back to the default proportional-based distribution across localities. + // If not specified, the default is 1. + // * :ref:`runtime values `. + // * :ref:`Zone aware routing support `. + google.protobuf.UInt32Value min_size = 1; + } + // Configures percentage of requests that will be considered for zone aware routing // if zone aware routing is configured. If not specified, the default is 100%. // * :ref:`runtime values `. @@ -41,6 +70,18 @@ message LocalityLbConfig { // requests as if all hosts are unhealthy. This can help avoid potentially overwhelming a // failing service. bool fail_traffic_on_panic = 3; + + // If set to true, Envoy will force LocalityDirect routing if a local locality exists. + bool force_locality_direct_routing = 4 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + ForceLocalZone force_local_zone = 5; + + // Determines how locality percentages are computed: + // - HEALTHY_HOSTS_NUM: proportional to the count of healthy hosts. + // - HEALTHY_HOSTS_WEIGHT: proportional to the weights of healthy hosts. + // Default value is HEALTHY_HOSTS_NUM if unset. + LocalityBasis locality_basis = 6; } // Configuration for :ref:`locality weighted load balancing @@ -111,4 +152,10 @@ message ConsistentHashingLbConfig { // This is an O(N) algorithm, unlike other load balancers. 
Using a lower ``hash_balance_factor`` results in more hosts // being probed, so use a higher value if you require better performance. google.protobuf.UInt32Value hash_balance_factor = 2 [(validate.rules).uint32 = {gte: 100}]; + + // Specifies a list of hash policies to use for ring hash load balancing. If ``hash_policy`` is + // set, then + // :ref:`route level hash policy ` + // will be ignored. + repeated config.route.v3.RouteAction.HashPolicy hash_policy = 3; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto index ab8367a401a..e2e4ade8236 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto @@ -14,7 +14,7 @@ option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/loa option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Weighted Round Robin Locality-Picking Load Balancing Policy] -// [#not-implemented-hide:] +// [#extension: envoy.load_balancing_policies.wrr_locality] // Configuration for the wrr_locality LB policy. See the :ref:`load balancing architecture overview // ` for more information. 
diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/http_11_proxy/v3/upstream_http_11_connect.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/http_11_proxy/v3/upstream_http_11_connect.proto new file mode 100644 index 00000000000..2c9b5333f41 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/http_11_proxy/v3/upstream_http_11_connect.proto @@ -0,0 +1,38 @@ +syntax = "proto3"; + +package envoy.extensions.transport_sockets.http_11_proxy.v3; + +import "envoy/config/core/v3/base.proto"; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.transport_sockets.http_11_proxy.v3"; +option java_outer_classname = "UpstreamHttp11ConnectProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/http_11_proxy/v3;http_11_proxyv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Upstream HTTP/1.1 Proxy] +// [#extension: envoy.transport_sockets.http_11_proxy] + +// HTTP/1.1 proxy transport socket establishes an upstream connection to a proxy address +// instead of the target host's address. This behavior is triggered when the transport +// socket is configured and proxy information is provided. +// +// Behavior when proxying: +// ======================= +// When an upstream connection is established, instead of connecting directly to the endpoint +// address, the client will connect to the specified proxy address, send an HTTP/1.1 ``CONNECT`` request +// indicating the endpoint address, and process the response. If the response has HTTP status 200, +// the connection will be passed down to the underlying transport socket. +// +// Configuring proxy information: +// ============================== +// Set ``typed_filter_metadata`` in :ref:`LbEndpoint.Metadata ` or :ref:`LocalityLbEndpoints.Metadata `. 
+// using the key ``envoy.http11_proxy_transport_socket.proxy_address`` and the +// proxy address in ``config::core::v3::Address`` format. +// +message Http11ProxyUpstreamTransport { + // The underlying transport socket being wrapped. Defaults to plaintext (raw_buffer) if unset. + config.core.v3.TransportSocket transport_socket = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto index c1a3f5b33b3..9bc5fb5d029 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto @@ -24,7 +24,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Common TLS configuration] -// [#next-free-field: 6] +// [#next-free-field: 7] message TlsParameters { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.TlsParameters"; @@ -45,6 +45,23 @@ message TlsParameters { TLSv1_3 = 4; } + enum CompliancePolicy { + // FIPS_202205 configures a TLS connection to use: + // + // * TLS 1.2 or 1.3 + // * For TLS 1.2, only ECDHE_[RSA|ECDSA]_WITH_AES_*_GCM_SHA*. + // * For TLS 1.3, only AES-GCM + // * P-256 or P-384 for key agreement. + // * For server signatures, only ``PKCS#1/PSS`` with ``SHA256/384/512``, or ECDSA + // with P-256 or P-384. + // + // .. attention:: + // + // Please refer to `BoringSSL policies `_ + // for details. + FIPS_202205 = 0; + } + // Minimum TLS protocol version. By default, it's ``TLSv1_2`` for both clients and servers. // // TLS protocol versions below TLSv1_2 require setting compatible ciphers with the @@ -157,6 +174,11 @@ message TlsParameters { // rsa_pkcs1_sha1 // ecdsa_sha1 repeated string signature_algorithms = 5; + + // Compliance policies configure various aspects of the TLS based on the given policy. 
+ // The policies are applied last during configuration and may override the other TLS + // parameters, or any previous policy. + repeated CompliancePolicy compliance_policies = 6 [(validate.rules).repeated = {max_items: 1}]; } // BoringSSL private key method configuration. The private key methods are used for external @@ -232,12 +254,13 @@ message TlsCertificate { config.core.v3.WatchedDirectory watched_directory = 7; // BoringSSL private key method provider. This is an alternative to :ref:`private_key - // ` field. This can't be - // marked as ``oneof`` due to API compatibility reasons. Setting both :ref:`private_key - // ` and - // :ref:`private_key_provider - // ` fields will result in an - // error. + // ` field. + // When both :ref:`private_key ` and + // :ref:`private_key_provider ` fields are set, + // ``private_key_provider`` takes precedence. + // If ``private_key_provider`` is unavailable and :ref:`fallback + // ` + // is enabled, ``private_key`` will be used. PrivateKeyProvider private_key_provider = 6; // The password to decrypt the TLS private key. If this field is not set, it is assumed that the @@ -290,12 +313,12 @@ message TlsSessionTicketKeys { // respect to the TLS handshake. // [#not-implemented-hide:] message CertificateProviderPluginInstance { - // Provider instance name. If not present, defaults to "default". + // Provider instance name. // // Instance names should generally be defined not in terms of the underlying provider // implementation (e.g., "file_watcher") but rather in terms of the function of the // certificates (e.g., "foo_deployment_identity"). - string instance_name = 1; + string instance_name = 1 [(validate.rules).string = {min_len: 1}]; // Opaque name used to specify certificate instances or types. For example, "ROOTCA" to specify // a root-certificate (validation context) or "example.com" to specify a certificate for a @@ -322,6 +345,13 @@ message SubjectAltNameMatcher { // Matcher for SAN value. 
// + // If the :ref:`san_type ` + // is :ref:`DNS ` + // and the matcher type is :ref:`exact `, DNS wildcards are evaluated + // according to the rules in https://www.rfc-editor.org/rfc/rfc6125#section-6.4.3. + // For example, ``*.example.com`` would match ``test.example.com`` but not ``example.com`` and not + // ``a.b.example.com``. + // // The string matching for OTHER_NAME SAN values depends on their ASN.1 type: // // * OBJECT: Validated against its dotted numeric notation (e.g., "1.2.3.4") diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/secret.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/secret.proto index 83ad364c4bf..94660e2da9f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/secret.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/secret.proto @@ -22,8 +22,13 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; message GenericSecret { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.GenericSecret"; - // Secret of generic type and is available to filters. + // Secret of generic type and is available to filters. It is expected + // that only only one of secret and secrets is set. config.core.v3.DataSource secret = 1 [(udpa.annotations.sensitive) = true]; + + // For cases where multiple associated secrets need to be distributed together. It is expected + // that only only one of secret and secrets is set. 
+ map secrets = 2 [(udpa.annotations.sensitive) = true]; } message SdsSecretConfig { diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto index 9d465c97321..d656c66b5d0 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto @@ -25,7 +25,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#extension: envoy.transport_sockets.tls] // The TLS contexts below provide the transport socket configuration for upstream/downstream TLS. -// [#next-free-field: 6] +// [#next-free-field: 8] message UpstreamTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.UpstreamTlsContext"; @@ -34,14 +34,32 @@ message UpstreamTlsContext { // // .. attention:: // - // Server certificate verification is not enabled by default. Configure - // :ref:`trusted_ca` to enable - // verification. + // Server certificate verification is not enabled by default. To enable verification, configure + // :ref:`trusted_ca`. CommonTlsContext common_tls_context = 1; // SNI string to use when creating TLS backend connections. string sni = 2 [(validate.rules).string = {max_bytes: 255}]; + // If true, replaces the SNI for the connection with the hostname of the upstream host, if + // the hostname is known due to either a DNS cluster type or the + // :ref:`hostname ` is set on + // the host. + // + // See :ref:`SNI configuration ` for details on how this + // interacts with other validation options. + bool auto_host_sni = 6; + + // If true, replaces any Subject Alternative Name (SAN) validations with a validation for a DNS SAN matching + // the SNI value sent. The validation uses the actual requested SNI, regardless of how the SNI is configured. 
+ // + // For common cases where an SNI value is present and the server certificate should include a corresponding SAN, + // this option ensures the SAN is properly validated. + // + // See the :ref:`validation configuration ` for how this interacts with + // other validation options. + bool auto_sni_san_validation = 7; + // If true, server-initiated TLS renegotiation will be allowed. // // .. attention:: @@ -50,43 +68,38 @@ message UpstreamTlsContext { bool allow_renegotiation = 3; // Maximum number of session keys (Pre-Shared Keys for TLSv1.3+, Session IDs and Session Tickets - // for TLSv1.2 and older) to store for the purpose of session resumption. + // for TLSv1.2 and older) to be stored for session resumption. // // Defaults to 1, setting this to 0 disables session resumption. google.protobuf.UInt32Value max_session_keys = 4; - // This field is used to control the enforcement, whereby the handshake will fail if the keyUsage extension - // is present and incompatible with the TLS usage. Currently, the default value is false (i.e., enforcement off) - // but it is expected to be changed to true by default in a future release. - // ``ssl.was_key_usage_invalid`` in :ref:`listener metrics ` will be set for certificate - // configurations that would fail if this option were set to true. + // Controls enforcement of the ``keyUsage`` extension in peer certificates. If set to ``true``, the handshake will fail if + // the ``keyUsage`` is incompatible with TLS usage. + // + // .. note:: + // The default value is ``false`` (i.e., enforcement off). It is expected to change to ``true`` in a future release. + // + // The ``ssl.was_key_usage_invalid`` in :ref:`listener metrics ` metric will be incremented + // for configurations that would fail if this option were enabled. 
google.protobuf.BoolValue enforce_rsa_key_usage = 5; } -// [#next-free-field: 11] +// [#next-free-field: 12] message DownstreamTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.DownstreamTlsContext"; enum OcspStaplePolicy { - // OCSP responses are optional. If an OCSP response is absent - // or expired, the associated certificate will be used for - // connections without an OCSP staple. + // OCSP responses are optional. If absent or expired, the certificate is used without stapling. LENIENT_STAPLING = 0; - // OCSP responses are optional. If an OCSP response is absent, - // the associated certificate will be used without an - // OCSP staple. If a response is provided but is expired, - // the associated certificate will not be used for - // subsequent connections. If no suitable certificate is found, - // the connection is rejected. + // OCSP responses are optional. If absent, the certificate is used without stapling. If present but expired, + // the certificate is not used for subsequent connections. Connections are rejected if no suitable certificate + // is found. STRICT_STAPLING = 1; - // OCSP responses are required. Configuration will fail if - // a certificate is provided without an OCSP response. If a - // response expires, the associated certificate will not be - // used connections. If no suitable certificate is found, the - // connection is rejected. + // OCSP responses are required. Connections fail if a certificate lacks a valid OCSP response. Expired responses + // prevent certificate use in new connections, and connections are rejected if no suitable certificate is available. MUST_STAPLE = 2; } @@ -119,51 +132,64 @@ message DownstreamTlsContext { bool disable_stateless_session_resumption = 7; } - // If set to true, the TLS server will not maintain a session cache of TLS sessions. (This is - // relevant only for TLSv1.2 and earlier.) 
+ // If ``true``, the TLS server will not maintain a session cache of TLS sessions. + // + // .. note:: + // This applies only to TLSv1.2 and earlier. + // bool disable_stateful_session_resumption = 10; - // If specified, ``session_timeout`` will change the maximum lifetime (in seconds) of the TLS session. - // Currently this value is used as a hint for the `TLS session ticket lifetime (for TLSv1.2) `_. - // Only seconds can be specified (fractional seconds are ignored). + // Maximum lifetime of TLS sessions. If specified, ``session_timeout`` will change the maximum lifetime + // of the TLS session. + // + // This serves as a hint for the `TLS session ticket lifetime (for TLSv1.2) `_. + // Only whole seconds are considered; fractional seconds are ignored. google.protobuf.Duration session_timeout = 6 [(validate.rules).duration = { lt {seconds: 4294967296} gte {} }]; - // Config for whether to use certificates if they do not have - // an accompanying OCSP response or if the response expires at runtime. - // Defaults to LENIENT_STAPLING + // Configuration for handling certificates without an OCSP response or with expired responses. + // + // Defaults to ``LENIENT_STAPLING`` OcspStaplePolicy ocsp_staple_policy = 8 [(validate.rules).enum = {defined_only: true}]; // Multiple certificates are allowed in Downstream transport socket to serve different SNI. - // If the client provides SNI but no such cert matched, it will decide to full scan certificates or not based on this config. - // Defaults to false. See more details in :ref:`Multiple TLS certificates `. + // This option controls the behavior when no matching certificate is found for the received SNI value, + // or no SNI value was sent. If enabled, all certificates will be evaluated for a match for non-SNI criteria + // such as key type and OCSP settings. If disabled, the first provided certificate will be used. + // Defaults to ``false``. See more details in :ref:`Multiple TLS certificates `. 
google.protobuf.BoolValue full_scan_certs_on_sni_mismatch = 9; + + // If ``true``, the downstream client's preferred cipher is used during the handshake. If ``false``, Envoy + // uses its preferred cipher. + // + // .. note:: + // This has no effect when using TLSv1_3. + // + bool prefer_client_ciphers = 11; } // TLS key log configuration. // The key log file format is "format used by NSS for its SSLKEYLOGFILE debugging output" (text taken from openssl man page) message TlsKeyLog { - // The path to save the TLS key log. + // Path to save the TLS key log. string path = 1 [(validate.rules).string = {min_len: 1}]; - // The local IP address that will be used to filter the connection which should save the TLS key log - // If it is not set, any local IP address will be matched. + // Local IP address ranges to filter connections for TLS key logging. If not set, matches any local IP address. repeated config.core.v3.CidrRange local_address_range = 2; - // The remote IP address that will be used to filter the connection which should save the TLS key log - // If it is not set, any remote IP address will be matched. + // Remote IP address ranges to filter connections for TLS key logging. If not set, matches any remote IP address. repeated config.core.v3.CidrRange remote_address_range = 3; } // TLS context shared by both client and server TLS contexts. -// [#next-free-field: 16] +// [#next-free-field: 17] message CommonTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.CommonTlsContext"; - // Config for Certificate provider to get certificates. This provider should allow certificates to be - // fetched/refreshed over the network asynchronously with respect to the TLS handshake. + // Config for the Certificate Provider to fetch certificates. Certificates are fetched/refreshed asynchronously over + // the network relative to the TLS handshake. 
// // DEPRECATED: This message is not currently used, but if we ever do need it, we will want to // move it out of CommonTlsContext and into common.proto, similar to the existing @@ -256,7 +282,7 @@ message CommonTlsContext { // fetched/refreshed over the network asynchronously with respect to the TLS handshake. // // The same number and types of certificates as :ref:`tls_certificates ` - // are valid in the the certificates fetched through this setting. + // are valid in the certificates fetched through this setting. // // If ``tls_certificates`` or ``tls_certificate_provider_instance`` are set, this field // is ignored. @@ -269,6 +295,14 @@ message CommonTlsContext { // [#not-implemented-hide:] CertificateProviderPluginInstance tls_certificate_provider_instance = 14; + // Custom TLS certificate selector. + // + // Select TLS certificate based on TLS client hello. + // If empty, defaults to native TLS certificate selection behavior: + // DNS SANs or Subject Common Name in TLS certificates is extracted as server name pattern to match SNI. + // [#extension-category: envoy.tls.certificate_selectors] + config.core.v3.TypedExtensionConfig custom_tls_certificate_selector = 16; + // Certificate provider for fetching TLS certificates. // [#not-implemented-hide:] CertificateProvider tls_certificate_certificate_provider = 9 @@ -287,13 +321,17 @@ message CommonTlsContext { // fetched/refreshed over the network asynchronously with respect to the TLS handshake. SdsSecretConfig validation_context_sds_secret_config = 7; - // Combined certificate validation context holds a default CertificateValidationContext - // and SDS config. When SDS server returns dynamic CertificateValidationContext, both dynamic - // and default CertificateValidationContext are merged into a new CertificateValidationContext - // for validation. 
This merge is done by Message::MergeFrom(), so dynamic - // CertificateValidationContext overwrites singular fields in default - // CertificateValidationContext, and concatenates repeated fields to default - // CertificateValidationContext, and logical OR is applied to boolean fields. + // Combines the default ``CertificateValidationContext`` with the SDS-provided dynamic context for certificate + // validation. + // + // When the SDS server returns a dynamic ``CertificateValidationContext``, it is merged + // with the default context using ``Message::MergeFrom()``. The merging rules are as follows: + // + // * **Singular Fields:** Dynamic fields override the default singular fields. + // * **Repeated Fields:** Dynamic repeated fields are concatenated with the default repeated fields. + // * **Boolean Fields:** Boolean fields are combined using a logical OR operation. + // + // The resulting ``CertificateValidationContext`` is used to perform certificate validation. CombinedCertificateValidationContext combined_validation_context = 8; // Certificate provider for fetching validation context. 
diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/auth/v3/attribute_context.proto b/xds/third_party/envoy/src/main/proto/envoy/service/auth/v3/attribute_context.proto new file mode 100644 index 00000000000..2c4fbb4b73e --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/service/auth/v3/attribute_context.proto @@ -0,0 +1,222 @@ +syntax = "proto3"; + +package envoy.service.auth.v3; + +import "envoy/config/core/v3/address.proto"; +import "envoy/config/core/v3/base.proto"; + +import "google/protobuf/timestamp.proto"; + +import "udpa/annotations/migrate.proto"; +import "udpa/annotations/status.proto"; +import "udpa/annotations/versioning.proto"; + +option java_package = "io.envoyproxy.envoy.service.auth.v3"; +option java_outer_classname = "AttributeContextProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3;authv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Attribute context] + +// See :ref:`network filter configuration overview ` +// and :ref:`HTTP filter configuration overview `. + +// An attribute is a piece of metadata that describes an activity on a network. +// For example, the size of an HTTP request, or the status code of an HTTP response. +// +// Each attribute has a type and a name, which is logically defined as a proto message field +// of the ``AttributeContext``. The ``AttributeContext`` is a collection of individual attributes +// supported by Envoy authorization system. 
+// [#comment: The following items are left out of this proto +// Request.Auth field for JWTs +// Request.Api for api management +// Origin peer that originated the request +// Caching Protocol +// request_context return values to inject back into the filter chain +// peer.claims -- from X.509 extensions +// Configuration +// - field mask to send +// - which return values from request_context are copied back +// - which return values are copied into request_headers] +// [#next-free-field: 14] +message AttributeContext { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.AttributeContext"; + + // This message defines attributes for a node that handles a network request. + // The node can be either a service or an application that sends, forwards, + // or receives the request. Service peers should fill in the ``service``, + // ``principal``, and ``labels`` as appropriate. + // [#next-free-field: 6] + message Peer { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.AttributeContext.Peer"; + + // The address of the peer, this is typically the IP address. + // It can also be UDS path, or others. + config.core.v3.Address address = 1; + + // The canonical service name of the peer. + // It should be set to :ref:`the HTTP x-envoy-downstream-service-cluster + // ` + // If a more trusted source of the service name is available through mTLS/secure naming, it + // should be used. + string service = 2; + + // The labels associated with the peer. + // These could be pod labels for Kubernetes or tags for VMs. + // The source of the labels could be an X.509 certificate or other configuration. + map labels = 3; + + // The authenticated identity of this peer. + // For example, the identity associated with the workload such as a service account. 
+ // If an X.509 certificate is used to assert the identity this field should be sourced from + // ``URI Subject Alternative Names``, ``DNS Subject Alternate Names`` or ``Subject`` in that order. + // The primary identity should be the principal. The principal format is issuer specific. + // + // Examples: + // + // - SPIFFE format is ``spiffe://trust-domain/path``. + // - Google account format is ``https://accounts.google.com/{userid}``. + string principal = 4; + + // The X.509 certificate used to authenticate the identify of this peer. + // When present, the certificate contents are encoded in URL and PEM format. + string certificate = 5; + } + + // Represents a network request, such as an HTTP request. + message Request { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.AttributeContext.Request"; + + // The timestamp when the proxy receives the first byte of the request. + google.protobuf.Timestamp time = 1; + + // Represents an HTTP request or an HTTP-like request. + HttpRequest http = 2; + } + + // This message defines attributes for an HTTP request. + // HTTP/1.x, HTTP/2, gRPC are all considered as HTTP requests. + // [#next-free-field: 14] + message HttpRequest { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.AttributeContext.HttpRequest"; + + // The unique ID for a request, which can be propagated to downstream + // systems. The ID should have low probability of collision + // within a single day for a specific service. + // For HTTP requests, it should be X-Request-ID or equivalent. + string id = 1; + + // The HTTP request method, such as ``GET``, ``POST``. + string method = 2; + + // The HTTP request headers. If multiple headers share the same key, they + // must be merged according to the HTTP spec. All header keys must be + // lower-cased, because HTTP header keys are case-insensitive. + // Header value is encoded as UTF-8 string. 
Non-UTF-8 characters will be replaced by "!". + // This field will not be set if + // :ref:`encode_raw_headers ` + // is set to true. + map headers = 3 + [(udpa.annotations.field_migrate).oneof_promotion = "headers_type"]; + + // A list of the raw HTTP request headers. This is used instead of + // :ref:`headers ` when + // :ref:`encode_raw_headers ` + // is set to true. + // + // Note that this is not actually a map type. ``header_map`` contains a single repeated field + // ``headers``. + // + // Here, only the ``key`` and ``raw_value`` fields will be populated for each HeaderValue, and + // that is only when + // :ref:`encode_raw_headers ` + // is set to true. + // + // Also, unlike the + // :ref:`headers ` + // field, headers with the same key are not combined into a single comma separated header. + config.core.v3.HeaderMap header_map = 13 + [(udpa.annotations.field_migrate).oneof_promotion = "headers_type"]; + + // The request target, as it appears in the first line of the HTTP request. This includes + // the URL path and query-string. No decoding is performed. + string path = 4; + + // The HTTP request ``Host`` or ``:authority`` header value. + string host = 5; + + // The HTTP URL scheme, such as ``http`` and ``https``. + string scheme = 6; + + // This field is always empty, and exists for compatibility reasons. The HTTP URL query is + // included in ``path`` field. + string query = 7; + + // This field is always empty, and exists for compatibility reasons. The URL fragment is + // not submitted as part of HTTP requests; it is unknowable. + string fragment = 8; + + // The HTTP request size in bytes. If unknown, it must be -1. + int64 size = 9; + + // The network protocol used with the request, such as "HTTP/1.0", "HTTP/1.1", or "HTTP/2". + // + // See :repo:`headers.h:ProtocolStrings ` for a list of all + // possible values. + string protocol = 10; + + // The HTTP request body. + string body = 11; + + // The HTTP request body in bytes. 
This is used instead of + // :ref:`body ` when + // :ref:`pack_as_bytes ` + // is set to true. + bytes raw_body = 12; + } + + // This message defines attributes for the underlying TLS session. + message TLSSession { + // SNI used for TLS session. + string sni = 1; + } + + // The source of a network activity, such as starting a TCP connection. + // In a multi hop network activity, the source represents the sender of the + // last hop. + Peer source = 1; + + // The destination of a network activity, such as accepting a TCP connection. + // In a multi hop network activity, the destination represents the receiver of + // the last hop. + Peer destination = 2; + + // Represents a network request, such as an HTTP request. + Request request = 4; + + // This is analogous to http_request.headers, however these contents will not be sent to the + // upstream server. Context_extensions provide an extension mechanism for sending additional + // information to the auth server without modifying the proto definition. It maps to the + // internal opaque context in the filter chain. + map context_extensions = 10; + + // Dynamic metadata associated with the request. + config.core.v3.Metadata metadata_context = 11; + + // Metadata associated with the selected route. + config.core.v3.Metadata route_metadata_context = 13; + + // TLS session details of the underlying connection. + // This is not populated by default and will be populated only if the ext_authz filter has + // been specifically configured to include this information. + // For HTTP ext_authz, that requires :ref:`include_tls_session ` + // to be set to true. + // For network ext_authz, that requires :ref:`include_tls_session ` + // to be set to true. 
+ TLSSession tls_session = 12; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/auth/v3/external_auth.proto b/xds/third_party/envoy/src/main/proto/envoy/service/auth/v3/external_auth.proto new file mode 100644 index 00000000000..520a4ff4f31 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/service/auth/v3/external_auth.proto @@ -0,0 +1,157 @@ +syntax = "proto3"; + +package envoy.service.auth.v3; + +import "envoy/config/core/v3/base.proto"; +import "envoy/service/auth/v3/attribute_context.proto"; +import "envoy/type/v3/http_status.proto"; + +import "google/protobuf/struct.proto"; +import "google/rpc/status.proto"; + +import "envoy/annotations/deprecation.proto"; +import "udpa/annotations/status.proto"; +import "udpa/annotations/versioning.proto"; + +option java_package = "io.envoyproxy.envoy.service.auth.v3"; +option java_outer_classname = "ExternalAuthProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3;authv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Authorization service] + +// The authorization service request messages used by external authorization :ref:`network filter +// ` and :ref:`HTTP filter `. + +// A generic interface for performing authorization check on incoming +// requests to a networked service. +service Authorization { + // Performs authorization check based on the attributes associated with the + // incoming request, and returns status `OK` or not `OK`. + rpc Check(CheckRequest) returns (CheckResponse) { + } +} + +message CheckRequest { + option (udpa.annotations.versioning).previous_message_type = "envoy.service.auth.v2.CheckRequest"; + + // The request attributes. + AttributeContext attributes = 1; +} + +// HTTP attributes for a denied response. 
+message DeniedHttpResponse { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.DeniedHttpResponse"; + + // This field allows the authorization service to send an HTTP response status code to the + // downstream client. If not set, Envoy sends ``403 Forbidden`` HTTP status code by default. + type.v3.HttpStatus status = 1; + + // This field allows the authorization service to send HTTP response headers + // to the downstream client. Note that the :ref:`append field in HeaderValueOption ` defaults to + // false when used in this message. + repeated config.core.v3.HeaderValueOption headers = 2; + + // This field allows the authorization service to send a response body data + // to the downstream client. + string body = 3; +} + +// HTTP attributes for an OK response. +// [#next-free-field: 9] +message OkHttpResponse { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.OkHttpResponse"; + + // HTTP entity headers in addition to the original request headers. This allows the authorization + // service to append, to add or to override headers from the original request before + // dispatching it to the upstream. Note that the :ref:`append field in HeaderValueOption ` defaults to + // false when used in this message. By setting the ``append`` field to ``true``, + // the filter will append the correspondent header value to the matched request header. + // By leaving ``append`` as false, the filter will either add a new header, or override an existing + // one if there is a match. + repeated config.core.v3.HeaderValueOption headers = 2; + + // HTTP entity headers to remove from the original request before dispatching + // it to the upstream. This allows the authorization service to act on auth + // related headers (like ``Authorization``), process them, and consume them. 
+ // Under this model, the upstream will either receive the request (if it's + // authorized) or not receive it (if it's not), but will not see headers + // containing authorization credentials. + // + // Pseudo headers (such as ``:authority``, ``:method``, ``:path`` etc), as well as + // the header ``Host``, may not be removed as that would make the request + // malformed. If mentioned in ``headers_to_remove`` these special headers will + // be ignored. + // + // When using the HTTP service this must instead be set by the HTTP + // authorization service as a comma separated list like so: + // ``x-envoy-auth-headers-to-remove: one-auth-header, another-auth-header``. + repeated string headers_to_remove = 5; + + // This field has been deprecated in favor of :ref:`CheckResponse.dynamic_metadata + // `. Until it is removed, + // setting this field overrides :ref:`CheckResponse.dynamic_metadata + // `. + google.protobuf.Struct dynamic_metadata = 3 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // This field allows the authorization service to send HTTP response headers + // to the downstream client on success. Note that the :ref:`append field in HeaderValueOption ` + // defaults to false when used in this message. + repeated config.core.v3.HeaderValueOption response_headers_to_add = 6; + + // This field allows the authorization service to set (and overwrite) query + // string parameters on the original request before it is sent upstream. + repeated config.core.v3.QueryParameter query_parameters_to_set = 7; + + // This field allows the authorization service to specify which query parameters + // should be removed from the original request before it is sent upstream. Each + // element in this list is a case-sensitive query parameter name to be removed. + repeated string query_parameters_to_remove = 8; +} + +// Intended for gRPC and Network Authorization servers ``only``. 
+// [#next-free-field: 6] +message CheckResponse { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.CheckResponse"; + + // Status ``OK`` allows the request. Any other status indicates the request should be denied, and + // for HTTP filter, if not overridden by :ref:`denied HTTP response status ` + // Envoy sends ``403 Forbidden`` HTTP status code by default. + google.rpc.Status status = 1; + + // An message that contains HTTP response attributes. This message is + // used when the authorization service needs to send custom responses to the + // downstream client or, to modify/add request headers being dispatched to the upstream. + oneof http_response { + // Supplies http attributes for a denied response. + DeniedHttpResponse denied_response = 2; + + // Supplies http attributes for an ok response. + OkHttpResponse ok_response = 3; + + // Supplies http attributes for an error response. This is used when the authorization + // service encounters an internal error and wants to return custom headers and body to the + // downstream client. When ``error_response`` is set, the ext_authz filter increments the + // ``ext_authz_error`` stat and respects the :ref:`failure_mode_allow + // ` + // configuration. The HTTP status code, headers, and body are taken from the + // :ref:`DeniedHttpResponse ` message. + // If the status field is not set, Envoy sends the status code configured via + // :ref:`status_on_error `, + // which defaults to ``403 Forbidden``. + DeniedHttpResponse error_response = 5; + } + + // Optional response metadata that will be emitted as dynamic metadata to be consumed by the next + // filter. This metadata lives in a namespace specified by the canonical name of extension filter + // that requires it: + // + // - :ref:`envoy.filters.http.ext_authz ` for HTTP filter. + // - :ref:`envoy.filters.network.ext_authz ` for network filter. 
+ google.protobuf.Struct dynamic_metadata = 4; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/discovery.proto b/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/discovery.proto index b7270f246de..e1ce827a48f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/discovery.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/discovery.proto @@ -41,18 +41,29 @@ message ResourceName { DynamicParameterConstraints dynamic_parameter_constraints = 2; } +// [#not-implemented-hide:] +// An error associated with a specific resource name, returned to the +// client by the server. +message ResourceError { + // The name of the resource. + ResourceName resource_name = 1; + + // The error reported for the resource. + google.rpc.Status error_detail = 2; +} + // A DiscoveryRequest requests a set of versioned resources of the same type for // a given Envoy node on some API. // [#next-free-field: 8] message DiscoveryRequest { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.DiscoveryRequest"; - // The version_info provided in the request messages will be the version_info + // The ``version_info`` provided in the request messages will be the ``version_info`` // received with the most recent successfully processed response or empty on // the first request. It is expected that no new request is sent after a // response is received until the Envoy instance is ready to ACK/NACK the new // configuration. ACK/NACK takes place by returning the new API config version - // as applied or the previous API config version respectively. Each type_url + // as applied or the previous API config version respectively. Each ``type_url`` // (see below) has an independent version associated with it. string version_info = 1; @@ -61,10 +72,10 @@ message DiscoveryRequest { // List of resources to subscribe to, e.g. list of cluster names or a route // configuration name. 
If this is empty, all resources for the API are - // returned. LDS/CDS may have empty resource_names, which will cause all + // returned. LDS/CDS may have empty ``resource_names``, which will cause all // resources for the Envoy instance to be returned. The LDS and CDS responses // will then imply a number of resources that need to be fetched via EDS/RDS, - // which will be explicitly enumerated in resource_names. + // which will be explicitly enumerated in ``resource_names``. repeated string resource_names = 3; // [#not-implemented-hide:] @@ -72,21 +83,27 @@ message DiscoveryRequest { // parameters along with each resource name. Clients that populate this // field must be able to handle responses from the server where resources // are wrapped in a Resource message. - // Note that it is legal for a request to have some resources listed - // in ``resource_names`` and others in ``resource_locators``. + // + // .. note:: + // It is legal for a request to have some resources listed + // in ``resource_names`` and others in ``resource_locators``. + // repeated ResourceLocator resource_locators = 7; // Type of the resource that is being requested, e.g. - // "type.googleapis.com/envoy.api.v2.ClusterLoadAssignment". This is implicit + // ``type.googleapis.com/envoy.api.v2.ClusterLoadAssignment``. This is implicit // in requests made via singleton xDS APIs such as CDS, LDS, etc. but is // required for ADS. string type_url = 4; - // nonce corresponding to DiscoveryResponse being ACK/NACKed. See above - // discussion on version_info and the DiscoveryResponse nonce comment. This - // may be empty only if 1) this is a non-persistent-stream xDS such as HTTP, - // or 2) the client has not yet accepted an update in this xDS stream (unlike - // delta, where it is populated only for new explicit ACKs). + // nonce corresponding to ``DiscoveryResponse`` being ACK/NACKed. See above + // discussion on ``version_info`` and the ``DiscoveryResponse`` nonce comment. 
This + // may be empty only if: + // + // * This is a non-persistent-stream xDS such as HTTP, or + // * The client has not yet accepted an update in this xDS stream (unlike + // delta, where it is populated only for new explicit ACKs). + // string response_nonce = 5; // This is populated when the previous :ref:`DiscoveryResponse ` @@ -96,7 +113,7 @@ message DiscoveryRequest { google.rpc.Status error_detail = 6; } -// [#next-free-field: 7] +// [#next-free-field: 8] message DiscoveryResponse { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.DiscoveryResponse"; @@ -109,35 +126,46 @@ message DiscoveryResponse { // [#not-implemented-hide:] // Canary is used to support two Envoy command line flags: // - // * --terminate-on-canary-transition-failure. When set, Envoy is able to + // * ``--terminate-on-canary-transition-failure``. When set, Envoy is able to // terminate if it detects that configuration is stuck at canary. Consider // this example sequence of updates: - // - Management server applies a canary config successfully. - // - Management server rolls back to a production config. - // - Envoy rejects the new production config. + // + // * Management server applies a canary config successfully. + // * Management server rolls back to a production config. + // * Envoy rejects the new production config. + // // Since there is no sensible way to continue receiving configuration // updates, Envoy will then terminate and apply production config from a // clean slate. - // * --dry-run-canary. When set, a canary response will never be applied, only + // + // * ``--dry-run-canary``. When set, a canary response will never be applied, only // validated via a dry run. + // bool canary = 3; // Type URL for resources. Identifies the xDS API when muxing over ADS. - // Must be consistent with the type_url in the 'resources' repeated Any (if non-empty). + // Must be consistent with the ``type_url`` in the 'resources' repeated Any (if non-empty). 
string type_url = 4; // For gRPC based subscriptions, the nonce provides a way to explicitly ack a - // specific DiscoveryResponse in a following DiscoveryRequest. Additional + // specific ``DiscoveryResponse`` in a following ``DiscoveryRequest``. Additional // messages may have been sent by Envoy to the management server for the - // previous version on the stream prior to this DiscoveryResponse, that were + // previous version on the stream prior to this ``DiscoveryResponse``, that were // unprocessed at response send time. The nonce allows the management server - // to ignore any further DiscoveryRequests for the previous version until a - // DiscoveryRequest bearing the nonce. The nonce is optional and is not + // to ignore any further ``DiscoveryRequests`` for the previous version until a + // ``DiscoveryRequest`` bearing the nonce. The nonce is optional and is not // required for non-stream based xDS implementations. string nonce = 5; // The control plane instance that sent the response. config.core.v3.ControlPlane control_plane = 6; + + // [#not-implemented-hide:] + // Errors associated with specific resources. Clients are expected to + // remember the most recent error for a given resource across responses; + // the error condition is not considered to be cleared until a response is + // received that contains the resource in the 'resources' field. + repeated ResourceError resource_errors = 7; } // DeltaDiscoveryRequest and DeltaDiscoveryResponse are used in a new gRPC @@ -153,25 +181,28 @@ message DiscoveryResponse { // connected to it. // // In Delta xDS the nonce field is required and used to pair -// DeltaDiscoveryResponse to a DeltaDiscoveryRequest ACK or NACK. -// Optionally, a response message level system_version_info is present for +// ``DeltaDiscoveryResponse`` to a ``DeltaDiscoveryRequest`` ACK or NACK. +// Optionally, a response message level ``system_version_info`` is present for // debugging purposes only. 
// -// DeltaDiscoveryRequest plays two independent roles. Any DeltaDiscoveryRequest -// can be either or both of: [1] informing the server of what resources the -// client has gained/lost interest in (using resource_names_subscribe and -// resource_names_unsubscribe), or [2] (N)ACKing an earlier resource update from -// the server (using response_nonce, with presence of error_detail making it a NACK). -// Additionally, the first message (for a given type_url) of a reconnected gRPC stream +// ``DeltaDiscoveryRequest`` plays two independent roles. Any ``DeltaDiscoveryRequest`` +// can be either or both of: +// +// * Informing the server of what resources the client has gained/lost interest in +// (using ``resource_names_subscribe`` and ``resource_names_unsubscribe``), or +// * (N)ACKing an earlier resource update from the server (using ``response_nonce``, +// with presence of ``error_detail`` making it a NACK). +// +// Additionally, the first message (for a given ``type_url``) of a reconnected gRPC stream // has a third role: informing the server of the resources (and their versions) -// that the client already possesses, using the initial_resource_versions field. +// that the client already possesses, using the ``initial_resource_versions`` field. // // As with state-of-the-world, when multiple resource types are multiplexed (ADS), -// all requests/acknowledgments/updates are logically walled off by type_url: +// all requests/acknowledgments/updates are logically walled off by ``type_url``: // a Cluster ACK exists in a completely separate world from a prior Route NACK. -// In particular, initial_resource_versions being sent at the "start" of every -// gRPC stream actually entails a message for each type_url, each with its own -// initial_resource_versions. +// In particular, ``initial_resource_versions`` being sent at the "start" of every +// gRPC stream actually entails a message for each ``type_url``, each with its own +// ``initial_resource_versions``. 
// [#next-free-field: 10] message DeltaDiscoveryRequest { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.DeltaDiscoveryRequest"; @@ -187,23 +218,24 @@ message DeltaDiscoveryRequest { // DeltaDiscoveryRequests allow the client to add or remove individual // resources to the set of tracked resources in the context of a stream. - // All resource names in the resource_names_subscribe list are added to the - // set of tracked resources and all resource names in the resource_names_unsubscribe + // All resource names in the ``resource_names_subscribe`` list are added to the + // set of tracked resources and all resource names in the ``resource_names_unsubscribe`` // list are removed from the set of tracked resources. // - // *Unlike* state-of-the-world xDS, an empty resource_names_subscribe or - // resource_names_unsubscribe list simply means that no resources are to be + // *Unlike* state-of-the-world xDS, an empty ``resource_names_subscribe`` or + // ``resource_names_unsubscribe`` list simply means that no resources are to be // added or removed to the resource list. // *Like* state-of-the-world xDS, the server must send updates for all tracked // resources, but can also send updates for resources the client has not subscribed to. // - // NOTE: the server must respond with all resources listed in resource_names_subscribe, - // even if it believes the client has the most recent version of them. The reason: - // the client may have dropped them, but then regained interest before it had a chance - // to send the unsubscribe message. See DeltaSubscriptionStateTest.RemoveThenAdd. + // .. note:: + // The server must respond with all resources listed in ``resource_names_subscribe``, + // even if it believes the client has the most recent version of them. The reason: + // the client may have dropped them, but then regained interest before it had a chance + // to send the unsubscribe message. See DeltaSubscriptionStateTest.RemoveThenAdd. 
// - // These two fields can be set in any DeltaDiscoveryRequest, including ACKs - // and initial_resource_versions. + // These two fields can be set in any ``DeltaDiscoveryRequest``, including ACKs + // and ``initial_resource_versions``. // // A list of Resource names to add to the list of tracked resources. repeated string resource_names_subscribe = 3; @@ -214,31 +246,40 @@ message DeltaDiscoveryRequest { // [#not-implemented-hide:] // Alternative to ``resource_names_subscribe`` field that allows specifying dynamic parameters // along with each resource name. - // Note that it is legal for a request to have some resources listed - // in ``resource_names_subscribe`` and others in ``resource_locators_subscribe``. + // + // .. note:: + // It is legal for a request to have some resources listed + // in ``resource_names_subscribe`` and others in ``resource_locators_subscribe``. + // repeated ResourceLocator resource_locators_subscribe = 8; // [#not-implemented-hide:] // Alternative to ``resource_names_unsubscribe`` field that allows specifying dynamic parameters // along with each resource name. - // Note that it is legal for a request to have some resources listed - // in ``resource_names_unsubscribe`` and others in ``resource_locators_unsubscribe``. + // + // .. note:: + // It is legal for a request to have some resources listed + // in ``resource_names_unsubscribe`` and others in ``resource_locators_unsubscribe``. + // repeated ResourceLocator resource_locators_unsubscribe = 9; // Informs the server of the versions of the resources the xDS client knows of, to enable the // client to continue the same logical xDS session even in the face of gRPC stream reconnection. - // It will not be populated: [1] in the very first stream of a session, since the client will - // not yet have any resources, [2] in any message after the first in a stream (for a given - // type_url), since the server will already be correctly tracking the client's state. 
- // (In ADS, the first message *of each type_url* of a reconnected stream populates this map.) + // It will not be populated: + // + // * In the very first stream of a session, since the client will not yet have any resources. + // * In any message after the first in a stream (for a given ``type_url``), since the server will + // already be correctly tracking the client's state. + // + // (In ADS, the first message of each ``type_url`` of a reconnected stream populates this map.) + // The map's keys are names of xDS resources known to the xDS client. + // The map's values are opaque resource versions. map initial_resource_versions = 5; - // When the DeltaDiscoveryRequest is a ACK or NACK message in response - // to a previous DeltaDiscoveryResponse, the response_nonce must be the - // nonce in the DeltaDiscoveryResponse. - // Otherwise (unlike in DiscoveryRequest) response_nonce must be omitted. + // When the ``DeltaDiscoveryRequest`` is an ACK or NACK message in response + // to a previous ``DeltaDiscoveryResponse``, the ``response_nonce`` must be the + // nonce in the ``DeltaDiscoveryResponse``. + // Otherwise (unlike in ``DiscoveryRequest``) ``response_nonce`` must be omitted. string response_nonce = 6; // This is populated when the previous :ref:`DiscoveryResponse ` @@ -247,7 +288,7 @@ message DeltaDiscoveryRequest { google.rpc.Status error_detail = 7; } -// [#next-free-field: 9] +// [#next-free-field: 10] message DeltaDiscoveryResponse { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.DeltaDiscoveryResponse"; @@ -256,37 +297,46 @@ message DeltaDiscoveryResponse { string system_version_info = 1; // The response resources. These are typed resources, whose types must match - // the type_url field. + // the ``type_url`` field. repeated Resource resources = 2; // field id 3 IS available! // Type URL for resources. Identifies the xDS API when muxing over ADS.
- // Must be consistent with the type_url in the Any within 'resources' if 'resources' is non-empty. + // Must be consistent with the ``type_url`` in the Any within 'resources' if 'resources' is non-empty. string type_url = 4; - // Resources names of resources that have be deleted and to be removed from the xDS Client. + // Resource names of resources that have been deleted and to be removed from the xDS Client. // Removed resources for missing resources can be ignored. repeated string removed_resources = 6; - // Alternative to removed_resources that allows specifying which variant of + // Alternative to ``removed_resources`` that allows specifying which variant of // a resource is being removed. This variant must be used for any resource // for which dynamic parameter constraints were sent to the client. repeated ResourceName removed_resource_names = 8; - // The nonce provides a way for DeltaDiscoveryRequests to uniquely - // reference a DeltaDiscoveryResponse when (N)ACKing. The nonce is required. + // The nonce provides a way for ``DeltaDiscoveryRequests`` to uniquely + // reference a ``DeltaDiscoveryResponse`` when (N)ACKing. The nonce is required. string nonce = 5; // [#not-implemented-hide:] // The control plane instance that sent the response. config.core.v3.ControlPlane control_plane = 7; + + // [#not-implemented-hide:] + // Errors associated with specific resources. + // + // .. note:: + // A resource in this field with a status of NOT_FOUND should be treated the same as + // a resource listed in the ``removed_resources`` or ``removed_resource_names`` fields. + // + repeated ResourceError resource_errors = 9; } // A set of dynamic parameter constraints associated with a variant of an individual xDS resource. 
// These constraints determine whether the resource matches a subscription based on the set of // dynamic parameters in the subscription, as specified in the -// :ref:`ResourceLocator.dynamic_parameters` +// :ref:`ResourceLocator.dynamic_parameters ` // field. This allows xDS implementations (clients, servers, and caching proxies) to determine // which variant of a resource is appropriate for a given client. message DynamicParameterConstraints { @@ -340,8 +390,11 @@ message Resource { // [#not-implemented-hide:] message CacheControl { // If true, xDS proxies may not cache this resource. - // Note that this does not apply to clients other than xDS proxies, which must cache resources - // for their own use, regardless of the value of this field. + // + // .. note:: + // This does not apply to clients other than xDS proxies, which must cache resources + // for their own use, regardless of the value of this field. + // bool do_not_cache = 1; } @@ -371,7 +424,7 @@ message Resource { // configuration for the resource will be removed. // // The TTL can be refreshed or changed by sending a response that doesn't change the resource - // version. In this case the resource field does not need to be populated, which allows for + // version. In this case the ``resource`` field does not need to be populated, which allows for // light-weight "heartbeat" updates to keep a resource with a TTL alive. // // The TTL feature is meant to support configurations that should be removed in the event of diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/status/v3/csds.proto b/xds/third_party/envoy/src/main/proto/envoy/service/status/v3/csds.proto index 1c51f2bac37..de62fbf9b0f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/service/status/v3/csds.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/service/status/v3/csds.proto @@ -72,6 +72,11 @@ enum ClientConfigStatus { // config dump is not the NACKed version, but the most recent accepted one. 
If // no config is accepted yet, the attached config dump will be empty. CLIENT_NACKED = 3; + + // Client received an error from the control plane. The attached config + // dump is the most recent accepted one. If no config is accepted yet, + // the attached config dump will be empty. + CLIENT_RECEIVED_ERROR = 4; } // Request for client status of clients identified by a list of NodeMatchers. diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/address.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/address.proto new file mode 100644 index 00000000000..8a03a5320af --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/address.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; + +package envoy.type.matcher.v3; + +import "xds/core/v3/cidr.proto"; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.type.matcher.v3"; +option java_outer_classname = "AddressProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3;matcherv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Address Matcher] + +// Match an IP against a repeated CIDR range. This matcher is intended to be +// used in other matchers, for example in the filter state matcher to match a +// filter state object as an IP. 
+message AddressMatcher { + repeated xds.core.v3.CidrRange ranges = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/filter_state.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/filter_state.proto index f813178ae05..8c38a515ae9 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/filter_state.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/filter_state.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package envoy.type.matcher.v3; +import "envoy/type/matcher/v3/address.proto"; import "envoy/type/matcher/v3/string.proto"; import "udpa/annotations/status.proto"; @@ -25,5 +26,8 @@ message FilterStateMatcher { // Matches the filter state object as a string value. StringMatcher string_match = 2; + + // Matches the filter state object as an IP instance. + AddressMatcher address_match = 3; } } diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/metadata.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/metadata.proto index d3316e88a88..30abde97c09 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/metadata.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/metadata.proto @@ -16,11 +16,11 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Metadata matcher] -// MetadataMatcher provides a general interface to check if a given value is matched in -// :ref:`Metadata `. It uses `filter` and `path` to retrieve the value -// from the Metadata and then check if it's matched to the specified value. +// ``MetadataMatcher`` provides a general interface to check if a given value is matched in +// :ref:`Metadata `. It uses ``filter`` and ``path`` to retrieve the value +// from the ``Metadata`` and then check if it's matched to the specified value. // -// For example, for the following Metadata: +// For example, for the following ``Metadata``: // // ..
code-block:: yaml // @@ -41,8 +41,8 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // - string_value: m // - string_value: n // -// The following MetadataMatcher is matched as the path [a, b, c] will retrieve a string value "pro" -// from the Metadata which is matched to the specified prefix match. +// The following ``MetadataMatcher`` is matched as the path ``[a, b, c]`` will retrieve a string value ``pro`` +// from the ``Metadata`` which is matched to the specified prefix match. // // .. code-block:: yaml // @@ -55,7 +55,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // string_match: // prefix: pr // -// The following MetadataMatcher is matched as the code will match one of the string values in the +// The following ``MetadataMatcher`` is matched as the code will match one of the string values in the // list at the path [a, t]. // // .. code-block:: yaml @@ -70,7 +70,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // string_match: // exact: m // -// An example use of MetadataMatcher is specifying additional metadata in envoy.filters.http.rbac to +// An example use of ``MetadataMatcher`` is specifying additional metadata in ``envoy.filters.http.rbac`` to // enforce access control based on dynamic metadata in a request. See :ref:`Permission // ` and :ref:`Principal // `. @@ -79,9 +79,11 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; message MetadataMatcher { option (udpa.annotations.versioning).previous_message_type = "envoy.type.matcher.MetadataMatcher"; - // Specifies the segment in a path to retrieve value from Metadata. - // Note: Currently it's not supported to retrieve a value from a list in Metadata. This means that - // if the segment key refers to a list, it has to be the last segment in a path. + // Specifies the segment in a path to retrieve value from ``Metadata``. + // + // .. 
note:: + // Currently it's not supported to retrieve a value from a list in ``Metadata``. This means that + // if the segment key refers to a list, it has to be the last segment in a path. message PathSegment { option (udpa.annotations.versioning).previous_message_type = "envoy.type.matcher.MetadataMatcher.PathSegment"; @@ -89,18 +91,18 @@ message MetadataMatcher { oneof segment { option (validate.required) = true; - // If specified, use the key to retrieve the value in a Struct. + // If specified, use the key to retrieve the value in a ``Struct``. string key = 1 [(validate.rules).string = {min_len: 1}]; } } - // The filter name to retrieve the Struct from the Metadata. + // The filter name to retrieve the ``Struct`` from the ``Metadata``. string filter = 1 [(validate.rules).string = {min_len: 1}]; - // The path to retrieve the Value from the Struct. + // The path to retrieve the ``Value`` from the ``Struct``. repeated PathSegment path = 2 [(validate.rules).repeated = {min_items: 1}]; - // The MetadataMatcher is matched if the value retrieved by path is matched to this value. + // The ``MetadataMatcher`` is matched if the value retrieved by path is matched to this value. ValueMatcher value = 3 [(validate.rules).message = {required: true}]; // If true, the match result will be inverted. diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto index 10033749acd..56d39565ca5 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto @@ -38,7 +38,10 @@ message StringMatcher { string exact = 1; // The input string must have the prefix specified here. - // Note: empty prefix is not allowed, please use regex instead. + // + // .. note:: + // + // Empty prefix match is not allowed, please use ``safe_regex`` instead. 
// // Examples: // @@ -46,7 +49,10 @@ message StringMatcher { string prefix = 2 [(validate.rules).string = {min_len: 1}]; // The input string must have the suffix specified here. - // Note: empty prefix is not allowed, please use regex instead. + // + // .. note:: + // + // Empty suffix match is not allowed, please use ``safe_regex`` instead. // // Examples: // @@ -57,7 +63,10 @@ message StringMatcher { RegexMatcher safe_regex = 5 [(validate.rules).message = {required: true}]; // The input string must have the substring specified here. - // Note: empty contains match is not allowed, please use regex instead. + // + // .. note:: + // + // Empty contains match is not allowed, please use ``safe_regex`` instead. // // Examples: // @@ -69,9 +78,10 @@ message StringMatcher { xds.core.v3.TypedExtensionConfig custom = 8; } - // If true, indicates the exact/prefix/suffix/contains matching should be case insensitive. This - // has no effect for the safe_regex match. - // For example, the matcher ``data`` will match both input string ``Data`` and ``data`` if set to true. + // If ``true``, indicates the exact/prefix/suffix/contains matching should be case insensitive. This + // has no effect for the ``safe_regex`` match. + // For example, the matcher ``data`` will match both input string ``Data`` and ``data`` if this option + // is set to ``true``. bool ignore_case = 6; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/value.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/value.proto index d773c6057fc..8d65c457ccc 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/value.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/value.proto @@ -17,7 +17,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Value matcher] -// Specifies the way to match a ProtobufWkt::Value. Primitive values and ListValue are supported. 
+// Specifies the way to match a Protobuf::Value. Primitive values and ListValue are supported. // StructValue is not supported and is always not matched. // [#next-free-field: 8] message ValueMatcher { diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v3/metadata.proto b/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v3/metadata.proto index 20758577503..d131635bf9f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v3/metadata.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v3/metadata.proto @@ -14,10 +14,10 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Metadata] -// MetadataKey provides a general interface using ``key`` and ``path`` to retrieve value from -// :ref:`Metadata `. +// MetadataKey provides a way to retrieve values from +// :ref:`Metadata ` using a ``key`` and a ``path``. // -// For example, for the following Metadata: +// For example, consider the following Metadata: // // .. code-block:: yaml // @@ -28,7 +28,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // xyz: // hello: envoy // -// The following MetadataKey will retrieve a string value "bar" from the Metadata. +// The following MetadataKey would retrieve the string value "bar" from the Metadata: // // .. code-block:: yaml // @@ -40,8 +40,8 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; message MetadataKey { option (udpa.annotations.versioning).previous_message_type = "envoy.type.metadata.v2.MetadataKey"; - // Specifies the segment in a path to retrieve value from Metadata. - // Currently it is only supported to specify the key, i.e. field name, as one segment of a path. + // Specifies a segment in a path for retrieving values from Metadata. + // Currently, only key-based segments (field names) are supported. 
message PathSegment { option (udpa.annotations.versioning).previous_message_type = "envoy.type.metadata.v2.MetadataKey.PathSegment"; @@ -49,25 +49,27 @@ message MetadataKey { oneof segment { option (validate.required) = true; - // If specified, use the key to retrieve the value in a Struct. + // If specified, use this key to retrieve the value in a Struct. string key = 1 [(validate.rules).string = {min_len: 1}]; } } - // The key name of Metadata to retrieve the Struct from the metadata. - // Typically, it represents a builtin subsystem or custom extension. + // The key name of the Metadata from which to retrieve the Struct. + // This typically represents a builtin subsystem or custom extension. string key = 1 [(validate.rules).string = {min_len: 1}]; - // The path to retrieve the Value from the Struct. It can be a prefix or a full path, - // e.g. ``[prop, xyz]`` for a struct or ``[prop, foo]`` for a string in the example, - // which depends on the particular scenario. + // The path used to retrieve a specific Value from the Struct. + // This can be either a prefix or a full path, depending on the use case. + // For example, ``[prop, xyz]`` would retrieve a struct or ``[prop, foo]`` would retrieve a string + // in the example above. // - // Note: Due to that only the key type segment is supported, the path can not specify a list - // unless the list is the last segment. + // .. note:: + // Since only key-type segments are supported, a path cannot specify a list + // unless the list is the last segment. repeated PathSegment path = 2 [(validate.rules).repeated = {min_items: 1}]; } -// Describes what kind of metadata. +// Describes different types of metadata sources. 
message MetadataKind { option (udpa.annotations.versioning).previous_message_type = "envoy.type.metadata.v2.MetadataKind"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v3/custom_tag.proto b/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v3/custom_tag.proto index feb57e8eb66..cdb42a43507 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v3/custom_tag.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v3/custom_tag.proto @@ -17,7 +17,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Custom Tag] // Describes custom tags for the active span. -// [#next-free-field: 6] +// [#next-free-field: 7] message CustomTag { option (udpa.annotations.versioning).previous_message_type = "envoy.type.tracing.v2.CustomTag"; @@ -98,5 +98,12 @@ message CustomTag { // A custom tag to obtain tag value from the metadata. Metadata metadata = 5; + + // Custom tag value. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + string value = 6; } } diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/v3/http_status.proto b/xds/third_party/envoy/src/main/proto/envoy/type/v3/http_status.proto index ab03e1b2b72..40d697beefc 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/v3/http_status.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/v3/http_status.proto @@ -21,116 +21,172 @@ enum StatusCode { // `enum` type. Empty = 0; + // Continue - ``100`` status code. Continue = 100; + // OK - ``200`` status code. OK = 200; + // Created - ``201`` status code. Created = 201; + // Accepted - ``202`` status code. Accepted = 202; + // NonAuthoritativeInformation - ``203`` status code. NonAuthoritativeInformation = 203; + // NoContent - ``204`` status code. NoContent = 204; + // ResetContent - ``205`` status code. 
ResetContent = 205; + // PartialContent - ``206`` status code. PartialContent = 206; + // MultiStatus - ``207`` status code. MultiStatus = 207; + // AlreadyReported - ``208`` status code. AlreadyReported = 208; + // IMUsed - ``226`` status code. IMUsed = 226; + // MultipleChoices - ``300`` status code. MultipleChoices = 300; + // MovedPermanently - ``301`` status code. MovedPermanently = 301; + // Found - ``302`` status code. Found = 302; + // SeeOther - ``303`` status code. SeeOther = 303; + // NotModified - ``304`` status code. NotModified = 304; + // UseProxy - ``305`` status code. UseProxy = 305; + // TemporaryRedirect - ``307`` status code. TemporaryRedirect = 307; + // PermanentRedirect - ``308`` status code. PermanentRedirect = 308; + // BadRequest - ``400`` status code. BadRequest = 400; + // Unauthorized - ``401`` status code. Unauthorized = 401; + // PaymentRequired - ``402`` status code. PaymentRequired = 402; + // Forbidden - ``403`` status code. Forbidden = 403; + // NotFound - ``404`` status code. NotFound = 404; + // MethodNotAllowed - ``405`` status code. MethodNotAllowed = 405; + // NotAcceptable - ``406`` status code. NotAcceptable = 406; + // ProxyAuthenticationRequired - ``407`` status code. ProxyAuthenticationRequired = 407; + // RequestTimeout - ``408`` status code. RequestTimeout = 408; + // Conflict - ``409`` status code. Conflict = 409; + // Gone - ``410`` status code. Gone = 410; + // LengthRequired - ``411`` status code. LengthRequired = 411; + // PreconditionFailed - ``412`` status code. PreconditionFailed = 412; + // PayloadTooLarge - ``413`` status code. PayloadTooLarge = 413; + // URITooLong - ``414`` status code. URITooLong = 414; + // UnsupportedMediaType - ``415`` status code. UnsupportedMediaType = 415; + // RangeNotSatisfiable - ``416`` status code. RangeNotSatisfiable = 416; + // ExpectationFailed - ``417`` status code. ExpectationFailed = 417; + // MisdirectedRequest - ``421`` status code. 
MisdirectedRequest = 421; + // UnprocessableEntity - ``422`` status code. UnprocessableEntity = 422; + // Locked - ``423`` status code. Locked = 423; + // FailedDependency - ``424`` status code. FailedDependency = 424; + // UpgradeRequired - ``426`` status code. UpgradeRequired = 426; + // PreconditionRequired - ``428`` status code. PreconditionRequired = 428; + // TooManyRequests - ``429`` status code. TooManyRequests = 429; + // RequestHeaderFieldsTooLarge - ``431`` status code. RequestHeaderFieldsTooLarge = 431; + // InternalServerError - ``500`` status code. InternalServerError = 500; + // NotImplemented - ``501`` status code. NotImplemented = 501; + // BadGateway - ``502`` status code. BadGateway = 502; + // ServiceUnavailable - ``503`` status code. ServiceUnavailable = 503; + // GatewayTimeout - ``504`` status code. GatewayTimeout = 504; + // HTTPVersionNotSupported - ``505`` status code. HTTPVersionNotSupported = 505; + // VariantAlsoNegotiates - ``506`` status code. VariantAlsoNegotiates = 506; + // InsufficientStorage - ``507`` status code. InsufficientStorage = 507; + // LoopDetected - ``508`` status code. LoopDetected = 508; + // NotExtended - ``510`` status code. NotExtended = 510; + // NetworkAuthenticationRequired - ``511`` status code. 
NetworkAuthenticationRequired = 511; } diff --git a/xds/third_party/xds/import.sh b/xds/third_party/xds/import.sh index 9e4bf71d52f..7af5c8489d1 100755 --- a/xds/third_party/xds/import.sh +++ b/xds/third_party/xds/import.sh @@ -17,7 +17,7 @@ set -e # import VERSION from one of the google internal CLs -VERSION=024c85f92f20cab567a83acc50934c7f9711d124 +VERSION=2ac532fd44436293585084f8d94c6bdb17835af0 DOWNLOAD_URL="https://github.com/cncf/xds/archive/${VERSION}.tar.gz" DOWNLOAD_BASE_DIR="xds-${VERSION}" SOURCE_PROTO_BASE_DIR="${DOWNLOAD_BASE_DIR}" @@ -40,6 +40,7 @@ xds/annotations/v3/versioning.proto xds/core/v3/authority.proto xds/core/v3/collection_entry.proto xds/core/v3/context_params.proto +xds/core/v3/cidr.proto xds/core/v3/extension.proto xds/core/v3/resource_locator.proto xds/core/v3/resource_name.proto diff --git a/xds/third_party/xds/src/main/proto/xds/core/v3/cidr.proto b/xds/third_party/xds/src/main/proto/xds/core/v3/cidr.proto new file mode 100644 index 00000000000..b8471bc8078 --- /dev/null +++ b/xds/third_party/xds/src/main/proto/xds/core/v3/cidr.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package xds.core.v3; + +import "xds/annotations/v3/status.proto"; +import "google/protobuf/wrappers.proto"; + +import "validate/validate.proto"; + +option java_outer_classname = "CidrRangeProto"; +option java_multiple_files = true; +option java_package = "com.github.xds.core.v3"; +option go_package = "github.com/cncf/xds/go/xds/core/v3"; + +option (xds.annotations.v3.file_status).work_in_progress = true; + +// CidrRange specifies an IP Address and a prefix length to construct +// the subnet mask for a `CIDR `_ range. +message CidrRange { + // IPv4 or IPv6 address, e.g. ``192.0.0.0`` or ``2001:db8::``. + string address_prefix = 1 [(validate.rules).string = {min_len: 1}]; + + // Length of prefix, e.g. 0, 32. Defaults to 0 when unset. 
+ google.protobuf.UInt32Value prefix_len = 2 [(validate.rules).uint32 = {lte: 128}]; +} \ No newline at end of file diff --git a/xds/third_party/xds/src/main/proto/xds/data/orca/v3/orca_load_report.proto b/xds/third_party/xds/src/main/proto/xds/data/orca/v3/orca_load_report.proto index 53da75f78ac..1b0847585a4 100644 --- a/xds/third_party/xds/src/main/proto/xds/data/orca/v3/orca_load_report.proto +++ b/xds/third_party/xds/src/main/proto/xds/data/orca/v3/orca_load_report.proto @@ -10,7 +10,7 @@ option go_package = "github.com/cncf/xds/go/xds/data/orca/v3"; import "validate/validate.proto"; // See section `ORCA load report format` of the design document in -// :ref:`https://github.com/envoyproxy/envoy/issues/6614`. +// https://github.com/envoyproxy/envoy/issues/6614. message OrcaLoadReport { // CPU utilization expressed as a fraction of available CPU resources. This diff --git a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/cel.proto b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/cel.proto index b1ad1faa281..a45af9534a0 100644 --- a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/cel.proto +++ b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/cel.proto @@ -2,9 +2,7 @@ syntax = "proto3"; package xds.type.matcher.v3; -import "xds/annotations/v3/status.proto"; import "xds/type/v3/cel.proto"; - import "validate/validate.proto"; option java_package = "com.github.xds.type.matcher.v3"; @@ -12,8 +10,6 @@ option java_outer_classname = "CelProto"; option java_multiple_files = true; option go_package = "github.com/cncf/xds/go/xds/type/matcher/v3"; -option (xds.annotations.v3.file_status).work_in_progress = true; - // [#protodoc-title: Common Expression Language (CEL) matchers] // Performs a match by evaluating a `Common Expression Language @@ -24,14 +20,13 @@ option (xds.annotations.v3.file_status).work_in_progress = true; // // The match is ``true``, iff the result of the evaluation is a bool AND true. 
// In all other cases, the match is ``false``, including but not limited to: non-bool types, -// ``false``, ``null``,`` int(1)``, etc. +// ``false``, ``null``, ``int(1)``, etc. // In case CEL expression raises an error, the result of the evaluation is interpreted "no match". // // Refer to :ref:`Unified Matcher API ` documentation // for usage details. // -// [#comment:TODO(sergiitk): Link HttpAttributesMatchInput + usage example.] -// [#comment:TODO(sergiitk): When implemented, add the extension tag.] +// [#comment: envoy.matching.matchers.cel_matcher] message CelMatcher { // Either parsed or checked representation of the CEL program. type.v3.CelExpression expr_match = 1 [(validate.rules).message = {required: true}]; diff --git a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/http_inputs.proto b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/http_inputs.proto index 0dd80cd6f66..5709d64501b 100644 --- a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/http_inputs.proto +++ b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/http_inputs.proto @@ -2,15 +2,11 @@ syntax = "proto3"; package xds.type.matcher.v3; -import "xds/annotations/v3/status.proto"; - option java_package = "com.github.xds.type.matcher.v3"; option java_outer_classname = "HttpInputsProto"; option java_multiple_files = true; option go_package = "github.com/cncf/xds/go/xds/type/matcher/v3"; -option (xds.annotations.v3.file_status).work_in_progress = true; - // [#protodoc-title: Common HTTP Inputs] // Specifies that matching should be performed on the set of :ref:`HTTP attributes @@ -22,6 +18,6 @@ option (xds.annotations.v3.file_status).work_in_progress = true; // Refer to :ref:`Unified Matcher API ` documentation // for usage details. // -// [#comment:TODO(sergiitk): When implemented, add the extension tag.] 
+// [#comment: envoy.matching.inputs.cel_data_input] message HttpAttributesCelMatchInput { } diff --git a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/matcher.proto b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/matcher.proto index 4966b456bee..cc03ff6e98f 100644 --- a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/matcher.proto +++ b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/matcher.proto @@ -2,7 +2,6 @@ syntax = "proto3"; package xds.type.matcher.v3; -import "xds/annotations/v3/status.proto"; import "xds/core/v3/extension.proto"; import "xds/type/matcher/v3/string.proto"; @@ -21,8 +20,6 @@ option go_package = "github.com/cncf/xds/go/xds/type/matcher/v3"; // As an on_no_match might result in another matching tree being evaluated, this process // might repeat several times until the final OnMatch (or no match) is decided. message Matcher { - option (xds.annotations.v3.message_status).work_in_progress = true; - // What to do if a match is successful. message OnMatch { oneof on_match { @@ -38,6 +35,14 @@ message Matcher { // Protocol-specific action to take. core.v3.TypedExtensionConfig action = 2; } + + // If true and the Matcher matches, the action will be taken but the caller + // will behave as if the Matcher did not match. A subsequent matcher or + // on_no_match action will be used instead. + // This field is not supported in all contexts in which the matcher API is + // used. If this field is set in a context in which it's not supported, + // the resource will be rejected. + bool keep_matching = 3; } // A linear list of field matchers. 
diff --git a/xds/third_party/xds/src/main/proto/xds/type/v3/cel.proto b/xds/third_party/xds/src/main/proto/xds/type/v3/cel.proto index df4f81d90d2..043990401c6 100644 --- a/xds/third_party/xds/src/main/proto/xds/type/v3/cel.proto +++ b/xds/third_party/xds/src/main/proto/xds/type/v3/cel.proto @@ -47,6 +47,13 @@ message CelExpression { // // If set, takes precedence over ``cel_expr_parsed``. cel.expr.CheckedExpr cel_expr_checked = 4; + + // Unparsed expression in string form. For example, ``request.headers['x-env'] == 'prod'`` will + // get ``x-env`` header value and compare it with ``prod``. + // Check the `Common Expression Language `_ for more details. + // + // If set, takes precedence over ``cel_expr_parsed`` and ``cel_expr_checked``. + string cel_expr_string = 5; } // Extracts a string by evaluating a `Common Expression Language